From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001
From: Manjunath Kudlur
Date: Fri, 6 Nov 2015 16:27:58 -0800
Subject: TensorFlow: Initial commit of TensorFlow library.

TensorFlow is an open source software library for numerical computation
using data flow graphs.

Base CL: 107276108
---
 tensorflow/core/BUILD | 695 +++++++ tensorflow/core/client/tensor_c_api.cc | 370 ++++ tensorflow/core/client/tensor_c_api_test.cc | 94 + tensorflow/core/common_runtime/device.cc | 37 + tensorflow/core/common_runtime/device.h | 128 ++ tensorflow/core/common_runtime/device_factory.cc | 106 + tensorflow/core/common_runtime/device_factory.h | 69 + tensorflow/core/common_runtime/device_mgr.cc | 90 + tensorflow/core/common_runtime/device_mgr.h | 55 + tensorflow/core/common_runtime/device_set.cc | 68 + tensorflow/core/common_runtime/device_set.h | 64 + tensorflow/core/common_runtime/device_set_test.cc | 65 + tensorflow/core/common_runtime/eigen_thread_pool.h | 22 + tensorflow/core/common_runtime/executor.cc | 2118 ++++++++++++++++++++ tensorflow/core/common_runtime/executor.h | 209 ++ tensorflow/core/common_runtime/function.cc | 1335 ++++++++++++ tensorflow/core/common_runtime/function.h | 100 + tensorflow/core/common_runtime/gpu/dma_helper.h | 18 + .../core/common_runtime/gpu/gpu_allocator_retry.cc | 49 + .../core/common_runtime/gpu/gpu_allocator_retry.h | 36 + .../common_runtime/gpu/gpu_allocator_retry_test.cc | 175 ++ .../core/common_runtime/gpu/gpu_bfc_allocator.cc | 397 ++++ .../core/common_runtime/gpu/gpu_bfc_allocator.h | 156 ++ .../common_runtime/gpu/gpu_bfc_allocator_test.cc | 166 ++ .../core/common_runtime/gpu/gpu_debug_allocator.cc | 186 ++ .../core/common_runtime/gpu/gpu_debug_allocator.h | 68 + .../common_runtime/gpu/gpu_debug_allocator_test.cc | 207 ++ tensorflow/core/common_runtime/gpu/gpu_device.cc | 651 ++++++ tensorflow/core/common_runtime/gpu/gpu_device.h | 94 + .../core/common_runtime/gpu/gpu_device_factory.cc | 52 + .../core/common_runtime/gpu/gpu_event_mgr.cc | 132 ++ tensorflow/core/common_runtime/gpu/gpu_event_mgr.h | 118 ++ .../core/common_runtime/gpu/gpu_event_mgr_test.cc | 152 ++ tensorflow/core/common_runtime/gpu/gpu_init.cc | 147 ++ tensorflow/core/common_runtime/gpu/gpu_init.h | 19 + .../common_runtime/gpu/gpu_region_allocator.cc | 371 ++++ .../core/common_runtime/gpu/gpu_region_allocator.h | 146 ++ .../gpu/gpu_region_allocator_test.cc | 71 + .../core/common_runtime/gpu/gpu_stream_util.cc | 97 + .../core/common_runtime/gpu/gpu_stream_util.h | 30 + .../common_runtime/gpu/gpu_stream_util_test.cc | 137 ++ tensorflow/core/common_runtime/gpu/gpu_util.cc | 345 ++++ tensorflow/core/common_runtime/gpu/gpu_util.h | 89 + .../gpu/gpu_util_platform_specific.cc | 24 + .../core/common_runtime/gpu/pool_allocator.cc | 269 +++ .../core/common_runtime/gpu/pool_allocator.h | 202 ++ .../core/common_runtime/gpu/pool_allocator_test.cc | 203 ++ .../core/common_runtime/gpu/process_state.cc | 220 ++ tensorflow/core/common_runtime/gpu/process_state.h | 140 ++ .../core/common_runtime/gpu/visitable_allocator.h | 30 + .../core/common_runtime/gpu_device_context.h | 45 + .../common_runtime/kernel_benchmark_testlib.cc | 160 ++ .../core/common_runtime/kernel_benchmark_testlib.h | 52 + tensorflow/core/common_runtime/local_device.cc | 51 + tensorflow/core/common_runtime/local_device.h | 27 + tensorflow/core/common_runtime/local_session.cc | 500 +++++ tensorflow/core/common_runtime/local_session.h | 109 + .../core/common_runtime/local_session_test.cc | 314 +++ 
tensorflow/core/common_runtime/rendezvous_mgr.cc | 170 ++ tensorflow/core/common_runtime/rendezvous_mgr.h | 73 + tensorflow/core/common_runtime/session.cc | 51 + tensorflow/core/common_runtime/session_factory.cc | 41 + tensorflow/core/common_runtime/session_factory.h | 25 + tensorflow/core/common_runtime/session_options.cc | 9 + tensorflow/core/common_runtime/session_test.cc | 17 + tensorflow/core/common_runtime/simple_placer.cc | 559 ++++++ tensorflow/core/common_runtime/simple_placer.h | 81 + .../core/common_runtime/simple_placer_test.cc | 863 ++++++++ .../core/common_runtime/threadpool_device.cc | 55 + tensorflow/core/common_runtime/threadpool_device.h | 31 + .../common_runtime/threadpool_device_factory.cc | 31 + tensorflow/core/example/example.proto | 95 + tensorflow/core/example/feature.proto | 82 + .../core/framework/allocation_description.proto | 15 + tensorflow/core/framework/allocator.cc | 25 + tensorflow/core/framework/allocator.h | 132 ++ tensorflow/core/framework/allocator_test.cc | 61 + tensorflow/core/framework/attr_value.proto | 57 + tensorflow/core/framework/attr_value_util.cc | 382 ++++ tensorflow/core/framework/attr_value_util.h | 83 + tensorflow/core/framework/attr_value_util_test.cc | 91 + tensorflow/core/framework/bfloat16.cc | 22 + tensorflow/core/framework/bfloat16.h | 58 + tensorflow/core/framework/bfloat16_test.cc | 69 + tensorflow/core/framework/cancellation.cc | 79 + tensorflow/core/framework/cancellation.h | 121 ++ tensorflow/core/framework/cancellation_test.cc | 102 + tensorflow/core/framework/config.proto | 61 + tensorflow/core/framework/control_flow.h | 43 + tensorflow/core/framework/device_attributes.proto | 35 + tensorflow/core/framework/device_base.cc | 7 + tensorflow/core/framework/device_base.h | 172 ++ tensorflow/core/framework/fake_input.cc | 214 ++ tensorflow/core/framework/fake_input.h | 25 + tensorflow/core/framework/function.cc | 878 ++++++++ tensorflow/core/framework/function.h | 376 ++++ tensorflow/core/framework/function.proto | 68 + tensorflow/core/framework/function_test.cc | 634 ++++++ tensorflow/core/framework/function_testlib.cc | 146 ++ tensorflow/core/framework/function_testlib.h | 53 + tensorflow/core/framework/graph.proto | 103 + tensorflow/core/framework/graph_def_util.cc | 25 + tensorflow/core/framework/graph_def_util.h | 29 + tensorflow/core/framework/kernel_def.proto | 33 + tensorflow/core/framework/kernel_def_builder.cc | 47 + tensorflow/core/framework/kernel_def_builder.h | 77 + .../core/framework/kernel_def_builder_test.cc | 76 + tensorflow/core/framework/lookup_interface.cc | 45 + tensorflow/core/framework/lookup_interface.h | 65 + tensorflow/core/framework/node_def_builder.cc | 194 ++ tensorflow/core/framework/node_def_builder.h | 176 ++ tensorflow/core/framework/node_def_builder_test.cc | 1036 ++++++++++ tensorflow/core/framework/node_def_util.cc | 414 ++++ tensorflow/core/framework/node_def_util.h | 157 ++ tensorflow/core/framework/node_def_util_test.cc | 442 ++++ tensorflow/core/framework/numeric_op.h | 96 + tensorflow/core/framework/numeric_types.h | 15 + tensorflow/core/framework/op.cc | 135 ++ tensorflow/core/framework/op.h | 122 ++ tensorflow/core/framework/op_def.proto | 142 ++ tensorflow/core/framework/op_def_builder.cc | 447 +++++ tensorflow/core/framework/op_def_builder.h | 109 + tensorflow/core/framework/op_def_builder_test.cc | 519 +++++ tensorflow/core/framework/op_def_util.cc | 344 ++++ tensorflow/core/framework/op_def_util.h | 32 + tensorflow/core/framework/op_def_util_test.cc | 330 +++ 
tensorflow/core/framework/op_gen_lib.cc | 55 + tensorflow/core/framework/op_gen_lib.h | 24 + tensorflow/core/framework/op_kernel.cc | 749 +++++++ tensorflow/core/framework/op_kernel.h | 1250 ++++++++++++ tensorflow/core/framework/op_kernel_test.cc | 803 ++++++++ tensorflow/core/framework/op_segment.cc | 86 + tensorflow/core/framework/op_segment.h | 67 + tensorflow/core/framework/op_segment_test.cc | 142 ++ tensorflow/core/framework/queue_interface.h | 77 + tensorflow/core/framework/reader_interface.h | 66 + tensorflow/core/framework/reader_op_kernel.cc | 39 + tensorflow/core/framework/reader_op_kernel.h | 42 + tensorflow/core/framework/register_types.h | 90 + tensorflow/core/framework/rendezvous.cc | 263 +++ tensorflow/core/framework/rendezvous.h | 102 + tensorflow/core/framework/rendezvous_test.cc | 314 +++ tensorflow/core/framework/resource_mgr.cc | 146 ++ tensorflow/core/framework/resource_mgr.h | 280 +++ tensorflow/core/framework/resource_mgr_test.cc | 173 ++ tensorflow/core/framework/step_stats.proto | 58 + tensorflow/core/framework/summary.proto | 67 + tensorflow/core/framework/tensor.cc | 570 ++++++ tensorflow/core/framework/tensor.proto | 57 + tensorflow/core/framework/tensor_description.proto | 19 + tensorflow/core/framework/tensor_shape.cc | 138 ++ tensorflow/core/framework/tensor_shape.proto | 29 + tensorflow/core/framework/tensor_shape_test.cc | 75 + tensorflow/core/framework/tensor_slice.cc | 226 +++ tensorflow/core/framework/tensor_slice.h | 189 ++ tensorflow/core/framework/tensor_slice.proto | 34 + tensorflow/core/framework/tensor_slice_test.cc | 246 +++ tensorflow/core/framework/tensor_test.cc | 551 +++++ tensorflow/core/framework/tensor_testutil.cc | 43 + tensorflow/core/framework/tensor_testutil.h | 189 ++ tensorflow/core/framework/tensor_types.h | 92 + tensorflow/core/framework/tensor_util.cc | 28 + tensorflow/core/framework/tensor_util.h | 21 + tensorflow/core/framework/tensor_util_test.cc | 124 ++ tensorflow/core/framework/tracking_allocator.cc | 100 + tensorflow/core/framework/tracking_allocator.h | 80 + .../core/framework/tracking_allocator_test.cc | 115 ++ tensorflow/core/framework/type_traits.h | 69 + tensorflow/core/framework/types.cc | 210 ++ tensorflow/core/framework/types.h | 168 ++ tensorflow/core/framework/types.proto | 48 + tensorflow/core/framework/types_test.cc | 117 ++ tensorflow/core/graph/algorithm.cc | 107 + tensorflow/core/graph/algorithm.h | 40 + tensorflow/core/graph/algorithm_test.cc | 103 + tensorflow/core/graph/colors.cc | 25 + tensorflow/core/graph/colors.h | 14 + tensorflow/core/graph/costmodel.cc | 308 +++ tensorflow/core/graph/costmodel.h | 123 ++ tensorflow/core/graph/costutil.cc | 22 + tensorflow/core/graph/costutil.h | 19 + tensorflow/core/graph/default_device.h | 25 + tensorflow/core/graph/dot.cc | 289 +++ tensorflow/core/graph/dot.h | 43 + tensorflow/core/graph/edgeset.cc | 56 + tensorflow/core/graph/edgeset.h | 216 ++ tensorflow/core/graph/edgeset_test.cc | 95 + tensorflow/core/graph/equal_graph_def.cc | 176 ++ tensorflow/core/graph/equal_graph_def.h | 32 + tensorflow/core/graph/equal_graph_def_test.cc | 279 +++ tensorflow/core/graph/graph.cc | 319 +++ tensorflow/core/graph/graph.h | 440 ++++ tensorflow/core/graph/graph_constructor.cc | 385 ++++ tensorflow/core/graph/graph_constructor.h | 43 + tensorflow/core/graph/graph_constructor_test.cc | 190 ++ tensorflow/core/graph/graph_def_builder.cc | 121 ++ tensorflow/core/graph/graph_def_builder.h | 181 ++ tensorflow/core/graph/graph_partition.cc | 1050 ++++++++++ 
tensorflow/core/graph/graph_partition.h | 77 + tensorflow/core/graph/graph_partition_test.cc | 316 +++ tensorflow/core/graph/graph_test.cc | 252 +++ tensorflow/core/graph/node_builder.cc | 115 ++ tensorflow/core/graph/node_builder.h | 146 ++ tensorflow/core/graph/node_builder_test.cc | 59 + tensorflow/core/graph/optimizer_cse.cc | 220 ++ tensorflow/core/graph/optimizer_cse.h | 19 + tensorflow/core/graph/optimizer_cse_test.cc | 365 ++++ tensorflow/core/graph/subgraph.cc | 258 +++ tensorflow/core/graph/subgraph.h | 49 + tensorflow/core/graph/subgraph_test.cc | 305 +++ tensorflow/core/graph/tensor_id.cc | 41 + tensorflow/core/graph/tensor_id.h | 28 + tensorflow/core/graph/tensor_id_test.cc | 77 + tensorflow/core/graph/testlib.cc | 299 +++ tensorflow/core/graph/testlib.h | 141 ++ tensorflow/core/graph/types.h | 17 + tensorflow/core/kernels/adjust_contrast_op.cc | 121 ++ tensorflow/core/kernels/adjust_contrast_op.h | 64 + .../kernels/adjust_contrast_op_benchmark_test.cc | 43 + .../core/kernels/adjust_contrast_op_gpu.cu.cc | 22 + tensorflow/core/kernels/adjust_contrast_op_test.cc | 88 + tensorflow/core/kernels/aggregate_ops.cc | 238 +++ tensorflow/core/kernels/aggregate_ops.h | 211 ++ tensorflow/core/kernels/aggregate_ops_gpu.cu.cc | 141 ++ tensorflow/core/kernels/argmax_op.cc | 163 ++ tensorflow/core/kernels/argmax_op.h | 55 + tensorflow/core/kernels/argmax_op_gpu.cu.cc | 20 + tensorflow/core/kernels/assign_op.h | 92 + tensorflow/core/kernels/attention_ops.cc | 92 + tensorflow/core/kernels/avgpooling_op.cc | 418 ++++ tensorflow/core/kernels/avgpooling_op.h | 58 + tensorflow/core/kernels/avgpooling_op_gpu.cu.cc | 101 + tensorflow/core/kernels/batch_matmul_op.cc | 260 +++ tensorflow/core/kernels/batch_norm_op.cc | 223 +++ tensorflow/core/kernels/batch_norm_op.h | 133 ++ tensorflow/core/kernels/batch_norm_op_gpu.cu.cc | 17 + tensorflow/core/kernels/bcast_ops.cc | 71 + tensorflow/core/kernels/bias_op.cc | 112 ++ tensorflow/core/kernels/bias_op.h | 41 + tensorflow/core/kernels/bias_op_gpu.cu.cc | 23 + tensorflow/core/kernels/candidate_sampler_ops.cc | 243 +++ tensorflow/core/kernels/cast_op.cc | 233 +++ tensorflow/core/kernels/cast_op.h | 71 + tensorflow/core/kernels/cast_op_gpu.cu.cc | 45 + tensorflow/core/kernels/cast_op_test.cc | 100 + tensorflow/core/kernels/check_numerics_op.cc | 190 ++ .../core/kernels/check_numerics_op_gpu.cu.cc | 62 + tensorflow/core/kernels/cholesky_op.cc | 71 + tensorflow/core/kernels/concat_op.cc | 153 ++ tensorflow/core/kernels/concat_op.h | 27 + tensorflow/core/kernels/concat_op_cpu.cc | 122 ++ tensorflow/core/kernels/concat_op_gpu.cu.cc | 41 + tensorflow/core/kernels/concat_op_test.cc | 240 +++ tensorflow/core/kernels/constant_op.cc | 249 +++ tensorflow/core/kernels/constant_op.h | 25 + tensorflow/core/kernels/constant_op_gpu.cu.cc | 89 + tensorflow/core/kernels/constant_op_test.cc | 43 + tensorflow/core/kernels/control_flow_ops.cc | 359 ++++ tensorflow/core/kernels/control_flow_ops.h | 22 + tensorflow/core/kernels/control_flow_ops_test.cc | 71 + tensorflow/core/kernels/conv_2d.h | 127 ++ tensorflow/core/kernels/conv_grad_ops.cc | 1190 +++++++++++ tensorflow/core/kernels/conv_ops.cc | 373 ++++ tensorflow/core/kernels/conv_ops_gpu.cu.cc | 35 + tensorflow/core/kernels/conv_ops_gpu_2.cu.cc | 16 + tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 22 + tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc | 16 + tensorflow/core/kernels/core_ops_test.cc | 990 +++++++++ tensorflow/core/kernels/count_up_to_op.cc | 51 + tensorflow/core/kernels/cwise_op_abs.cc | 23 + 
tensorflow/core/kernels/cwise_op_add.cc | 21 + tensorflow/core/kernels/cwise_op_ceil.cc | 8 + tensorflow/core/kernels/cwise_op_complex.cc | 10 + tensorflow/core/kernels/cwise_op_conj.cc | 10 + tensorflow/core/kernels/cwise_op_cos.cc | 8 + tensorflow/core/kernels/cwise_op_div.cc | 21 + tensorflow/core/kernels/cwise_op_equal_to.cc | 21 + tensorflow/core/kernels/cwise_op_exp.cc | 8 + tensorflow/core/kernels/cwise_op_floor.cc | 8 + tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_add.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_div.cu.cc | 11 + .../core/kernels/cwise_op_gpu_equal_to.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc | 11 + .../core/kernels/cwise_op_gpu_greater_equal.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc | 11 + .../core/kernels/cwise_op_gpu_isfinite.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_less.cu.cc | 11 + .../core/kernels/cwise_op_gpu_less_equal.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_log.cu.cc | 11 + .../core/kernels/cwise_op_gpu_logical_and.cu.cc | 13 + .../core/kernels/cwise_op_gpu_logical_not.cu.cc | 11 + .../core/kernels/cwise_op_gpu_logical_or.cu.cc | 13 + tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc | 11 + .../core/kernels/cwise_op_gpu_not_equal_to.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_real.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_select.cu.cc | 15 + tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_square.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc | 11 + tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc | 11 + tensorflow/core/kernels/cwise_op_greater.cc | 21 + tensorflow/core/kernels/cwise_op_greater_equal.cc | 22 + tensorflow/core/kernels/cwise_op_imag.cc | 10 + tensorflow/core/kernels/cwise_op_inverse.cc | 8 + tensorflow/core/kernels/cwise_op_isfinite.cc | 8 + tensorflow/core/kernels/cwise_op_isinf.cc | 8 + tensorflow/core/kernels/cwise_op_isnan.cc | 8 + tensorflow/core/kernels/cwise_op_less.cc | 20 + tensorflow/core/kernels/cwise_op_less_equal.cc | 22 + tensorflow/core/kernels/cwise_op_log.cc | 8 + tensorflow/core/kernels/cwise_op_logical_and.cc | 10 + tensorflow/core/kernels/cwise_op_logical_not.cc | 10 + tensorflow/core/kernels/cwise_op_logical_or.cc | 10 + tensorflow/core/kernels/cwise_op_maximum.cc | 21 + tensorflow/core/kernels/cwise_op_minimum.cc | 21 + tensorflow/core/kernels/cwise_op_mod.cc | 6 + tensorflow/core/kernels/cwise_op_mul.cc | 21 + tensorflow/core/kernels/cwise_op_neg.cc | 9 + 
tensorflow/core/kernels/cwise_op_not_equal_to.cc | 10 + tensorflow/core/kernels/cwise_op_pow.cc | 9 + tensorflow/core/kernels/cwise_op_real.cc | 10 + tensorflow/core/kernels/cwise_op_rsqrt.cc | 8 + tensorflow/core/kernels/cwise_op_select.cc | 17 + tensorflow/core/kernels/cwise_op_sigmoid.cc | 8 + tensorflow/core/kernels/cwise_op_sign.cc | 19 + tensorflow/core/kernels/cwise_op_sin.cc | 8 + tensorflow/core/kernels/cwise_op_sqrt.cc | 8 + tensorflow/core/kernels/cwise_op_square.cc | 9 + tensorflow/core/kernels/cwise_op_sub.cc | 21 + tensorflow/core/kernels/cwise_op_tanh.cc | 8 + tensorflow/core/kernels/cwise_ops.h | 607 ++++++ tensorflow/core/kernels/cwise_ops_common.cc | 42 + tensorflow/core/kernels/cwise_ops_common.h | 390 ++++ tensorflow/core/kernels/cwise_ops_gpu_common.cu.h | 135 ++ tensorflow/core/kernels/cwise_ops_test.cc | 167 ++ tensorflow/core/kernels/decode_csv_op.cc | 222 ++ tensorflow/core/kernels/decode_jpeg_op.cc | 72 + tensorflow/core/kernels/decode_png_op.cc | 69 + tensorflow/core/kernels/decode_raw_op.cc | 90 + tensorflow/core/kernels/dense_update_ops.cc | 136 ++ tensorflow/core/kernels/dense_update_ops.h | 43 + tensorflow/core/kernels/dense_update_ops_gpu.cu.cc | 22 + tensorflow/core/kernels/determinant_op.cc | 66 + tensorflow/core/kernels/diag_op.cc | 93 + tensorflow/core/kernels/dynamic_partition_op.cc | 154 ++ .../core/kernels/dynamic_partition_op_test.cc | 145 ++ tensorflow/core/kernels/dynamic_stitch_op.cc | 158 ++ tensorflow/core/kernels/dynamic_stitch_op_test.cc | 133 ++ tensorflow/core/kernels/edit_distance_op.cc | 217 ++ tensorflow/core/kernels/encode_jpeg_op.cc | 114 ++ tensorflow/core/kernels/encode_png_op.cc | 52 + tensorflow/core/kernels/example_parsing_ops.cc | 444 ++++ tensorflow/core/kernels/fact_op.cc | 96 + tensorflow/core/kernels/fifo_queue.cc | 518 +++++ tensorflow/core/kernels/fifo_queue.h | 127 ++ tensorflow/core/kernels/fifo_queue_op.cc | 93 + tensorflow/core/kernels/fill_functor.h | 26 + .../core/kernels/fixed_length_record_reader_op.cc | 109 + tensorflow/core/kernels/gather_op.cc | 136 ++ tensorflow/core/kernels/gather_op_test.cc | 213 ++ tensorflow/core/kernels/identity_op.cc | 45 + tensorflow/core/kernels/identity_op.h | 25 + tensorflow/core/kernels/identity_op_test.cc | 56 + tensorflow/core/kernels/identity_reader_op.cc | 57 + tensorflow/core/kernels/in_topk_op.cc | 58 + .../core/kernels/initializable_lookup_table.cc | 41 + .../core/kernels/initializable_lookup_table.h | 103 + tensorflow/core/kernels/io.cc | 270 +++ tensorflow/core/kernels/io.h | 38 + tensorflow/core/kernels/l2loss_op.cc | 69 + tensorflow/core/kernels/l2loss_op.h | 24 + tensorflow/core/kernels/l2loss_op_gpu.cu.cc | 16 + tensorflow/core/kernels/linalg_ops_common.cc | 99 + tensorflow/core/kernels/linalg_ops_common.h | 123 ++ tensorflow/core/kernels/listdiff_op.cc | 75 + tensorflow/core/kernels/logging_ops.cc | 77 + tensorflow/core/kernels/logging_ops_test.cc | 87 + tensorflow/core/kernels/lookup_table_init_op.cc | 116 ++ tensorflow/core/kernels/lookup_table_op.cc | 166 ++ tensorflow/core/kernels/lookup_table_op.h | 80 + tensorflow/core/kernels/lookup_util.cc | 72 + tensorflow/core/kernels/lookup_util.h | 31 + tensorflow/core/kernels/lrn_op.cc | 228 +++ tensorflow/core/kernels/lrn_op_test.cc | 185 ++ tensorflow/core/kernels/matching_files_op.cc | 42 + tensorflow/core/kernels/matmul_op.cc | 214 ++ tensorflow/core/kernels/matmul_op.h | 40 + tensorflow/core/kernels/matmul_op_gpu.cu.cc | 32 + tensorflow/core/kernels/matmul_op_test.cc | 56 + tensorflow/core/kernels/matrix_inverse_op.cc | 
64 + tensorflow/core/kernels/maxpooling_op.cc | 554 +++++ tensorflow/core/kernels/maxpooling_op.h | 29 + tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | 261 +++ tensorflow/core/kernels/maxpooling_op_gpu.h | 42 + tensorflow/core/kernels/no_op.cc | 8 + tensorflow/core/kernels/no_op.h | 17 + tensorflow/core/kernels/ops_testutil.cc | 18 + tensorflow/core/kernels/ops_testutil.h | 191 ++ tensorflow/core/kernels/ops_util.cc | 113 ++ tensorflow/core/kernels/ops_util.h | 180 ++ tensorflow/core/kernels/ops_util_test.cc | 265 +++ tensorflow/core/kernels/pack_op.cc | 114 ++ tensorflow/core/kernels/pad_op.cc | 159 ++ tensorflow/core/kernels/pad_op.h | 27 + tensorflow/core/kernels/pad_op_gpu.cu.cc | 26 + tensorflow/core/kernels/pooling_ops_common.cc | 252 +++ tensorflow/core/kernels/pooling_ops_common.h | 264 +++ tensorflow/core/kernels/pooling_ops_common_gpu.h | 39 + tensorflow/core/kernels/queue_base.cc | 153 ++ tensorflow/core/kernels/queue_base.h | 77 + tensorflow/core/kernels/queue_ops.cc | 288 +++ tensorflow/core/kernels/random_crop_op.cc | 103 + tensorflow/core/kernels/random_crop_op_test.cc | 60 + tensorflow/core/kernels/random_op.cc | 276 +++ tensorflow/core/kernels/random_op.h | 16 + tensorflow/core/kernels/random_op_gpu.cu.cc | 152 ++ tensorflow/core/kernels/random_op_test.cc | 99 + tensorflow/core/kernels/random_shuffle_op.cc | 89 + tensorflow/core/kernels/random_shuffle_queue_op.cc | 740 +++++++ tensorflow/core/kernels/range_sampler.cc | 305 +++ tensorflow/core/kernels/range_sampler.h | 237 +++ tensorflow/core/kernels/range_sampler_test.cc | 320 +++ tensorflow/core/kernels/reader_base.cc | 156 ++ tensorflow/core/kernels/reader_base.h | 107 + tensorflow/core/kernels/reader_base.proto | 13 + tensorflow/core/kernels/reader_ops.cc | 132 ++ tensorflow/core/kernels/reduction_ops.h | 66 + tensorflow/core/kernels/reduction_ops_all.cc | 17 + tensorflow/core/kernels/reduction_ops_any.cc | 17 + tensorflow/core/kernels/reduction_ops_common.h | 302 +++ tensorflow/core/kernels/reduction_ops_gpu.cu.cc | 65 + tensorflow/core/kernels/reduction_ops_max.cc | 26 + tensorflow/core/kernels/reduction_ops_mean.cc | 12 + tensorflow/core/kernels/reduction_ops_min.cc | 26 + tensorflow/core/kernels/reduction_ops_prod.cc | 26 + tensorflow/core/kernels/reduction_ops_sum.cc | 37 + tensorflow/core/kernels/reduction_ops_test.cc | 73 + tensorflow/core/kernels/reference_gemm.h | 75 + tensorflow/core/kernels/relu_op.cc | 154 ++ tensorflow/core/kernels/relu_op.h | 79 + tensorflow/core/kernels/relu_op_gpu.cu.cc | 27 + tensorflow/core/kernels/reshape_op.cc | 29 + tensorflow/core/kernels/reshape_op.h | 83 + tensorflow/core/kernels/resize_area_op.cc | 139 ++ tensorflow/core/kernels/resize_bicubic_op.cc | 121 ++ tensorflow/core/kernels/resize_bilinear_op.cc | 109 + tensorflow/core/kernels/resize_bilinear_op_test.cc | 171 ++ .../core/kernels/resize_nearest_neighbor_op.cc | 89 + .../kernels/resize_nearest_neighbor_op_test.cc | 163 ++ tensorflow/core/kernels/restore_op.cc | 65 + tensorflow/core/kernels/restore_op_test.cc | 305 +++ tensorflow/core/kernels/reverse_op.cc | 139 ++ tensorflow/core/kernels/reverse_op.h | 28 + tensorflow/core/kernels/reverse_op_gpu.cu.cc | 33 + tensorflow/core/kernels/reverse_op_test.cc | 101 + tensorflow/core/kernels/reverse_sequence_op.cc | 170 ++ tensorflow/core/kernels/reverse_sequence_op.h | 56 + .../core/kernels/reverse_sequence_op_gpu.cu.cc | 26 + tensorflow/core/kernels/save_op.cc | 81 + tensorflow/core/kernels/save_op_test.cc | 443 ++++ tensorflow/core/kernels/scatter_op.cc | 167 ++ 
tensorflow/core/kernels/scatter_op_test.cc | 255 +++ tensorflow/core/kernels/segment_reduction_ops.cc | 466 +++++ .../core/kernels/segment_reduction_ops_test.cc | 157 ++ tensorflow/core/kernels/sendrecv_ops.cc | 116 ++ tensorflow/core/kernels/sendrecv_ops.h | 32 + tensorflow/core/kernels/sequence_ops.cc | 123 ++ tensorflow/core/kernels/shape_ops.cc | 261 +++ tensorflow/core/kernels/slice_op.cc | 242 +++ tensorflow/core/kernels/slice_op.h | 25 + tensorflow/core/kernels/slice_op_gpu.cu.cc | 31 + tensorflow/core/kernels/slice_op_test.cc | 73 + tensorflow/core/kernels/softmax_op.cc | 62 + tensorflow/core/kernels/softmax_op.h | 70 + tensorflow/core/kernels/softmax_op_gpu.cu.cc | 31 + tensorflow/core/kernels/softplus_op.cc | 97 + tensorflow/core/kernels/softplus_op.h | 46 + tensorflow/core/kernels/softplus_op_gpu.cu.cc | 25 + tensorflow/core/kernels/sparse_concat_op.cc | 139 ++ tensorflow/core/kernels/sparse_matmul_op.cc | 192 ++ tensorflow/core/kernels/sparse_matmul_op_test.cc | 139 ++ tensorflow/core/kernels/sparse_reorder_op.cc | 71 + tensorflow/core/kernels/sparse_to_dense_op.cc | 129 ++ tensorflow/core/kernels/sparse_to_dense_op_test.cc | 283 +++ tensorflow/core/kernels/split_op.cc | 146 ++ tensorflow/core/kernels/split_op.h | 31 + tensorflow/core/kernels/split_op_cpu.cc | 30 + tensorflow/core/kernels/split_op_gpu.cu.cc | 31 + .../core/kernels/string_to_hash_bucket_op.cc | 47 + tensorflow/core/kernels/string_to_number_op.cc | 71 + tensorflow/core/kernels/summary_image_op.cc | 169 ++ tensorflow/core/kernels/summary_image_op_test.cc | 141 ++ tensorflow/core/kernels/summary_op.cc | 141 ++ tensorflow/core/kernels/summary_op_test.cc | 282 +++ tensorflow/core/kernels/text_line_reader_op.cc | 99 + tensorflow/core/kernels/tf_record_reader_op.cc | 76 + tensorflow/core/kernels/tile_ops.cc | 460 +++++ tensorflow/core/kernels/tile_ops.h | 48 + tensorflow/core/kernels/tile_ops_gpu.cu.cc | 38 + tensorflow/core/kernels/topk_op.cc | 71 + tensorflow/core/kernels/training_ops.cc | 884 ++++++++ tensorflow/core/kernels/training_ops.h | 65 + tensorflow/core/kernels/training_ops_gpu.cu.cc | 127 ++ tensorflow/core/kernels/training_ops_test.cc | 226 +++ tensorflow/core/kernels/transpose_op.cc | 190 ++ tensorflow/core/kernels/transpose_op.h | 19 + tensorflow/core/kernels/transpose_op_functor.h | 28 + tensorflow/core/kernels/transpose_op_gpu.cu.cc | 43 + tensorflow/core/kernels/unique_op.cc | 61 + tensorflow/core/kernels/unique_op_test.cc | 51 + tensorflow/core/kernels/unpack_op.cc | 96 + tensorflow/core/kernels/variable_ops.cc | 37 + tensorflow/core/kernels/variable_ops.h | 146 ++ tensorflow/core/kernels/where_op.cc | 74 + tensorflow/core/kernels/where_op.h | 65 + tensorflow/core/kernels/whole_file_read_ops.cc | 108 + tensorflow/core/kernels/xent_op.cc | 90 + tensorflow/core/kernels/xent_op.h | 102 + tensorflow/core/kernels/xent_op_gpu.cu.cc | 35 + tensorflow/core/kernels/xent_op_test.cc | 46 + tensorflow/core/lib/core/arena.cc | 246 +++ tensorflow/core/lib/core/arena.h | 90 + tensorflow/core/lib/core/arena_test.cc | 92 + tensorflow/core/lib/core/bit_cast_test.cc | 95 + tensorflow/core/lib/core/bits.h | 84 + tensorflow/core/lib/core/blocking_counter.h | 41 + tensorflow/core/lib/core/blocking_counter_test.cc | 36 + tensorflow/core/lib/core/casts.h | 85 + tensorflow/core/lib/core/coding.cc | 164 ++ tensorflow/core/lib/core/coding.h | 55 + tensorflow/core/lib/core/coding_test.cc | 168 ++ tensorflow/core/lib/core/command_line_flags.cc | 94 + tensorflow/core/lib/core/command_line_flags.h | 60 + 
tensorflow/core/lib/core/error_codes.proto | 145 ++ tensorflow/core/lib/core/errors.h | 131 ++ tensorflow/core/lib/core/notification.h | 42 + tensorflow/core/lib/core/notification_test.cc | 64 + tensorflow/core/lib/core/raw_coding.h | 43 + tensorflow/core/lib/core/refcount.cc | 35 + tensorflow/core/lib/core/refcount.h | 63 + tensorflow/core/lib/core/refcount_test.cc | 92 + tensorflow/core/lib/core/status.cc | 107 + tensorflow/core/lib/core/status_test.cc | 84 + tensorflow/core/lib/core/status_test_util.h | 20 + tensorflow/core/lib/core/stringpiece.cc | 57 + tensorflow/core/lib/core/stringpiece.h | 159 ++ tensorflow/core/lib/core/threadpool.cc | 108 + tensorflow/core/lib/core/threadpool.h | 59 + tensorflow/core/lib/core/threadpool_test.cc | 93 + tensorflow/core/lib/gtl/array_slice.h | 299 +++ tensorflow/core/lib/gtl/array_slice_internal.h | 253 +++ tensorflow/core/lib/gtl/array_slice_test.cc | 646 ++++++ tensorflow/core/lib/gtl/edit_distance.h | 82 + tensorflow/core/lib/gtl/edit_distance_test.cc | 125 ++ tensorflow/core/lib/gtl/inlined_vector.h | 839 ++++++++ tensorflow/core/lib/gtl/inlined_vector_test.cc | 905 +++++++++ tensorflow/core/lib/gtl/int_type.h | 343 ++++ tensorflow/core/lib/gtl/int_type_test.cc | 282 +++ tensorflow/core/lib/gtl/iterator_range.h | 49 + tensorflow/core/lib/gtl/iterator_range_test.cc | 60 + tensorflow/core/lib/gtl/manual_constructor.h | 230 +++ tensorflow/core/lib/gtl/manual_constructor_test.cc | 113 ++ tensorflow/core/lib/gtl/map_util.h | 123 ++ tensorflow/core/lib/gtl/map_util_test.cc | 47 + tensorflow/core/lib/gtl/stl_util.h | 130 ++ tensorflow/core/lib/gtl/top_n.h | 324 +++ tensorflow/core/lib/gtl/top_n_test.cc | 249 +++ tensorflow/core/lib/hash/crc32c.cc | 244 +++ tensorflow/core/lib/hash/crc32c.h | 39 + tensorflow/core/lib/hash/crc32c_test.cc | 51 + tensorflow/core/lib/hash/hash.cc | 113 ++ tensorflow/core/lib/hash/hash.h | 28 + tensorflow/core/lib/hash/hash_test.cc | 64 + tensorflow/core/lib/histogram/histogram.cc | 247 +++ tensorflow/core/lib/histogram/histogram.h | 119 ++ tensorflow/core/lib/histogram/histogram_test.cc | 112 ++ tensorflow/core/lib/io/block.cc | 236 +++ tensorflow/core/lib/io/block.h | 45 + tensorflow/core/lib/io/block_builder.cc | 107 + tensorflow/core/lib/io/block_builder.h | 57 + tensorflow/core/lib/io/format.cc | 148 ++ tensorflow/core/lib/io/format.h | 99 + tensorflow/core/lib/io/inputbuffer.cc | 112 ++ tensorflow/core/lib/io/inputbuffer.h | 62 + tensorflow/core/lib/io/inputbuffer_test.cc | 174 ++ tensorflow/core/lib/io/iterator.cc | 72 + tensorflow/core/lib/io/iterator.h | 93 + tensorflow/core/lib/io/match.cc | 31 + tensorflow/core/lib/io/match.h | 24 + tensorflow/core/lib/io/match_test.cc | 51 + tensorflow/core/lib/io/path.cc | 92 + tensorflow/core/lib/io/path.h | 47 + tensorflow/core/lib/io/path_test.cc | 65 + tensorflow/core/lib/io/record_reader.cc | 80 + tensorflow/core/lib/io/record_reader.h | 36 + tensorflow/core/lib/io/record_writer.cc | 42 + tensorflow/core/lib/io/record_writer.h | 34 + tensorflow/core/lib/io/recordio_test.cc | 245 +++ tensorflow/core/lib/io/table.cc | 169 ++ tensorflow/core/lib/io/table.h | 76 + tensorflow/core/lib/io/table_builder.cc | 263 +++ tensorflow/core/lib/io/table_builder.h | 87 + tensorflow/core/lib/io/table_format.txt | 8 + tensorflow/core/lib/io/table_options.h | 53 + tensorflow/core/lib/io/table_test.cc | 601 ++++++ tensorflow/core/lib/io/two_level_iterator.cc | 148 ++ tensorflow/core/lib/io/two_level_iterator.h | 30 + tensorflow/core/lib/jpeg/jpeg_handle.cc | 162 ++ 
tensorflow/core/lib/jpeg/jpeg_handle.h | 51 + tensorflow/core/lib/jpeg/jpeg_mem.cc | 557 +++++ tensorflow/core/lib/jpeg/jpeg_mem.h | 130 ++ tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc | 304 +++ tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg | Bin 0 -> 15416 bytes tensorflow/core/lib/jpeg/testdata/corrupt.jpg | Bin 0 -> 1552 bytes tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg | Bin 0 -> 755 bytes tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg | Bin 0 -> 5505 bytes tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg | Bin 0 -> 5092 bytes .../core/lib/jpeg/testdata/jpeg_merge_test1.jpg | Bin 0 -> 3771 bytes .../lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg | Bin 0 -> 5324 bytes tensorflow/core/lib/png/png_io.cc | 385 ++++ tensorflow/core/lib/png/png_io.h | 88 + tensorflow/core/lib/png/testdata/lena_gray.png | Bin 0 -> 1491 bytes tensorflow/core/lib/png/testdata/lena_rgba.png | Bin 0 -> 4032 bytes tensorflow/core/lib/random/distribution_sampler.cc | 80 + tensorflow/core/lib/random/distribution_sampler.h | 79 + .../core/lib/random/distribution_sampler_test.cc | 90 + tensorflow/core/lib/random/exact_uniform_int.h | 68 + tensorflow/core/lib/random/philox_random.h | 232 +++ tensorflow/core/lib/random/philox_random_test.cc | 58 + .../core/lib/random/philox_random_test_utils.h | 36 + tensorflow/core/lib/random/random.cc | 22 + tensorflow/core/lib/random/random.h | 16 + tensorflow/core/lib/random/random_distributions.h | 361 ++++ .../core/lib/random/random_distributions_test.cc | 270 +++ tensorflow/core/lib/random/random_test.cc | 21 + tensorflow/core/lib/random/simple_philox.cc | 24 + tensorflow/core/lib/random/simple_philox.h | 61 + tensorflow/core/lib/random/simple_philox_test.cc | 120 ++ tensorflow/core/lib/random/weighted_picker.cc | 203 ++ tensorflow/core/lib/random/weighted_picker.h | 118 ++ tensorflow/core/lib/random/weighted_picker_test.cc | 254 +++ tensorflow/core/lib/strings/numbers.cc | 260 +++ tensorflow/core/lib/strings/numbers.h | 92 + tensorflow/core/lib/strings/numbers_test.cc | 113 ++ tensorflow/core/lib/strings/ordered_code.cc | 515 +++++ tensorflow/core/lib/strings/ordered_code.h | 77 + tensorflow/core/lib/strings/ordered_code_test.cc | 1183 +++++++++++ tensorflow/core/lib/strings/str_util.cc | 312 +++ tensorflow/core/lib/strings/str_util.h | 149 ++ tensorflow/core/lib/strings/str_util_test.cc | 258 +++ tensorflow/core/lib/strings/strcat.cc | 194 ++ tensorflow/core/lib/strings/strcat.h | 229 +++ tensorflow/core/lib/strings/strcat_test.cc | 324 +++ tensorflow/core/lib/strings/stringprintf.cc | 85 + tensorflow/core/lib/strings/stringprintf.h | 37 + tensorflow/core/lib/strings/stringprintf_test.cc | 113 ++ tensorflow/core/ops/array_ops.cc | 892 +++++++++ tensorflow/core/ops/attention_ops.cc | 54 + tensorflow/core/ops/candidate_sampling_ops.cc | 351 ++++ tensorflow/core/ops/control_flow_ops.cc | 179 ++ tensorflow/core/ops/data_flow_ops.cc | 357 ++++ tensorflow/core/ops/image_ops.cc | 273 +++ tensorflow/core/ops/io_ops.cc | 332 +++ tensorflow/core/ops/linalg_ops.cc | 97 + tensorflow/core/ops/logging_ops.cc | 43 + tensorflow/core/ops/math_ops.cc | 1053 ++++++++++ tensorflow/core/ops/nn_ops.cc | 543 +++++ tensorflow/core/ops/no_op.cc | 10 + tensorflow/core/ops/parsing_ops.cc | 104 + tensorflow/core/ops/random_ops.cc | 108 + tensorflow/core/ops/sendrecv_ops.cc | 99 + tensorflow/core/ops/sparse_ops.cc | 134 ++ tensorflow/core/ops/state_ops.cc | 290 +++ tensorflow/core/ops/string_ops.cc | 21 + tensorflow/core/ops/summary_ops.cc | 115 ++ tensorflow/core/ops/training_ops.cc | 199 ++ 
tensorflow/core/platform/default/build_config.bzl | 65 + .../core/platform/default/build_config/BUILD | 85 + .../core/platform/default/build_config_root.bzl | 6 + .../core/platform/default/dynamic_annotations.h | 9 + tensorflow/core/platform/default/integral_types.h | 18 + tensorflow/core/platform/default/logging.cc | 125 ++ tensorflow/core/platform/default/logging.h | 258 +++ tensorflow/core/platform/default/mutex.h | 33 + tensorflow/core/platform/default/protobuf.h | 13 + .../core/platform/default/stream_executor_util.h | 19 + tensorflow/core/platform/default/test_benchmark.cc | 162 ++ .../core/platform/default/thread_annotations.h | 185 ++ tensorflow/core/platform/default/tracing.cc | 37 + tensorflow/core/platform/default/tracing_impl.h | 44 + tensorflow/core/platform/env.cc | 129 ++ tensorflow/core/platform/env_test.cc | 31 + tensorflow/core/platform/init_main.h | 16 + tensorflow/core/platform/integral_types_test.cc | 33 + tensorflow/core/platform/logging.h | 12 + tensorflow/core/platform/logging_test.cc | 76 + tensorflow/core/platform/port.h | 228 +++ tensorflow/core/platform/port_test.cc | 48 + tensorflow/core/platform/posix/env.cc | 385 ++++ tensorflow/core/platform/posix/port.cc | 92 + tensorflow/core/platform/protobuf.h | 29 + tensorflow/core/platform/protobuf_util.cc | 17 + tensorflow/core/platform/regexp.h | 33 + tensorflow/core/platform/stream_executor_util.h | 12 + tensorflow/core/platform/tensor_coding.cc | 53 + tensorflow/core/platform/tensor_coding.h | 40 + tensorflow/core/platform/test.cc | 39 + tensorflow/core/platform/test.h | 17 + tensorflow/core/platform/test_benchmark.h | 58 + tensorflow/core/platform/test_main.cc | 31 + tensorflow/core/platform/thread_annotations.h | 14 + tensorflow/core/platform/tracing.cc | 135 ++ tensorflow/core/platform/tracing.h | 205 ++ tensorflow/core/public/README.md | 90 + tensorflow/core/public/env.h | 273 +++ tensorflow/core/public/session.h | 125 ++ tensorflow/core/public/session_options.h | 50 + tensorflow/core/public/status.h | 96 + tensorflow/core/public/tensor.h | 472 +++++ tensorflow/core/public/tensor_c_api.h | 243 +++ tensorflow/core/public/tensor_shape.h | 239 +++ tensorflow/core/public/tensorflow_server.h | 19 + tensorflow/core/user_ops/fact.cc | 29 + tensorflow/core/util/bcast.cc | 120 ++ tensorflow/core/util/bcast.h | 99 + tensorflow/core/util/bcast_test.cc | 226 +++ tensorflow/core/util/device_name_utils.cc | 338 ++++ tensorflow/core/util/device_name_utils.h | 141 ++ tensorflow/core/util/device_name_utils_test.cc | 369 ++++ tensorflow/core/util/event.proto | 29 + tensorflow/core/util/events_writer.cc | 144 ++ tensorflow/core/util/events_writer.h | 77 + tensorflow/core/util/events_writer_test.cc | 198 ++ tensorflow/core/util/guarded_philox_random.cc | 39 + tensorflow/core/util/guarded_philox_random.h | 56 + tensorflow/core/util/padding.cc | 24 + tensorflow/core/util/padding.h | 37 + tensorflow/core/util/port.cc | 13 + tensorflow/core/util/port.h | 11 + tensorflow/core/util/saved_tensor_slice.proto | 76 + tensorflow/core/util/saved_tensor_slice_util.cc | 76 + tensorflow/core/util/saved_tensor_slice_util.h | 110 + .../core/util/saved_tensor_slice_util_test.cc | 32 + tensorflow/core/util/sparse/README.md | 222 ++ tensorflow/core/util/sparse/dim_comparator.h | 60 + tensorflow/core/util/sparse/group_iterator.cc | 49 + tensorflow/core/util/sparse/group_iterator.h | 120 ++ tensorflow/core/util/sparse/sparse_tensor.h | 353 ++++ tensorflow/core/util/sparse/sparse_tensor_test.cc | 467 +++++ tensorflow/core/util/tensor_slice_reader.cc | 
230 +++ tensorflow/core/util/tensor_slice_reader.h | 157 ++ tensorflow/core/util/tensor_slice_reader_cache.cc | 94 + tensorflow/core/util/tensor_slice_reader_cache.h | 73 + tensorflow/core/util/tensor_slice_reader_test.cc | 395 ++++ tensorflow/core/util/tensor_slice_set.cc | 148 ++ tensorflow/core/util/tensor_slice_set.h | 73 + tensorflow/core/util/tensor_slice_set_test.cc | 227 +++ tensorflow/core/util/tensor_slice_util.h | 88 + tensorflow/core/util/tensor_slice_util_test.cc | 91 + tensorflow/core/util/tensor_slice_writer.cc | 110 + tensorflow/core/util/tensor_slice_writer.h | 149 ++ tensorflow/core/util/tensor_slice_writer_test.cc | 248 +++ tensorflow/core/util/use_cudnn.cc | 20 + tensorflow/core/util/use_cudnn.h | 12 + tensorflow/core/util/util.cc | 81 + tensorflow/core/util/util.h | 40 + tensorflow/core/util/work_sharder.cc | 57 + tensorflow/core/util/work_sharder.h | 33 + tensorflow/core/util/work_sharder_test.cc | 57 + 788 files changed, 108161 insertions(+) create mode 100644 tensorflow/core/BUILD create mode 100644 tensorflow/core/client/tensor_c_api.cc create mode 100644 tensorflow/core/client/tensor_c_api_test.cc create mode 100644 tensorflow/core/common_runtime/device.cc create mode 100644 tensorflow/core/common_runtime/device.h create mode 100644 tensorflow/core/common_runtime/device_factory.cc create mode 100644 tensorflow/core/common_runtime/device_factory.h create mode 100644 tensorflow/core/common_runtime/device_mgr.cc create mode 100644 tensorflow/core/common_runtime/device_mgr.h create mode 100644 tensorflow/core/common_runtime/device_set.cc create mode 100644 tensorflow/core/common_runtime/device_set.h create mode 100644 tensorflow/core/common_runtime/device_set_test.cc create mode 100644 tensorflow/core/common_runtime/eigen_thread_pool.h create mode 100644 tensorflow/core/common_runtime/executor.cc create mode 100644 tensorflow/core/common_runtime/executor.h create mode 100644 tensorflow/core/common_runtime/function.cc create mode 100644 tensorflow/core/common_runtime/function.h create mode 100644 tensorflow/core/common_runtime/gpu/dma_helper.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_device.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_device.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_device_factory.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_event_mgr.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_init.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_init.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_region_allocator.h create mode 100644 
tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_stream_util.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_stream_util.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_util.cc create mode 100644 tensorflow/core/common_runtime/gpu/gpu_util.h create mode 100644 tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc create mode 100644 tensorflow/core/common_runtime/gpu/pool_allocator.cc create mode 100644 tensorflow/core/common_runtime/gpu/pool_allocator.h create mode 100644 tensorflow/core/common_runtime/gpu/pool_allocator_test.cc create mode 100644 tensorflow/core/common_runtime/gpu/process_state.cc create mode 100644 tensorflow/core/common_runtime/gpu/process_state.h create mode 100644 tensorflow/core/common_runtime/gpu/visitable_allocator.h create mode 100644 tensorflow/core/common_runtime/gpu_device_context.h create mode 100644 tensorflow/core/common_runtime/kernel_benchmark_testlib.cc create mode 100644 tensorflow/core/common_runtime/kernel_benchmark_testlib.h create mode 100644 tensorflow/core/common_runtime/local_device.cc create mode 100644 tensorflow/core/common_runtime/local_device.h create mode 100644 tensorflow/core/common_runtime/local_session.cc create mode 100644 tensorflow/core/common_runtime/local_session.h create mode 100644 tensorflow/core/common_runtime/local_session_test.cc create mode 100644 tensorflow/core/common_runtime/rendezvous_mgr.cc create mode 100644 tensorflow/core/common_runtime/rendezvous_mgr.h create mode 100644 tensorflow/core/common_runtime/session.cc create mode 100644 tensorflow/core/common_runtime/session_factory.cc create mode 100644 tensorflow/core/common_runtime/session_factory.h create mode 100644 tensorflow/core/common_runtime/session_options.cc create mode 100644 tensorflow/core/common_runtime/session_test.cc create mode 100644 tensorflow/core/common_runtime/simple_placer.cc create mode 100644 tensorflow/core/common_runtime/simple_placer.h create mode 100644 tensorflow/core/common_runtime/simple_placer_test.cc create mode 100644 tensorflow/core/common_runtime/threadpool_device.cc create mode 100644 tensorflow/core/common_runtime/threadpool_device.h create mode 100644 tensorflow/core/common_runtime/threadpool_device_factory.cc create mode 100644 tensorflow/core/example/example.proto create mode 100644 tensorflow/core/example/feature.proto create mode 100644 tensorflow/core/framework/allocation_description.proto create mode 100644 tensorflow/core/framework/allocator.cc create mode 100644 tensorflow/core/framework/allocator.h create mode 100644 tensorflow/core/framework/allocator_test.cc create mode 100644 tensorflow/core/framework/attr_value.proto create mode 100644 tensorflow/core/framework/attr_value_util.cc create mode 100644 tensorflow/core/framework/attr_value_util.h create mode 100644 tensorflow/core/framework/attr_value_util_test.cc create mode 100644 tensorflow/core/framework/bfloat16.cc create mode 100644 tensorflow/core/framework/bfloat16.h create mode 100644 tensorflow/core/framework/bfloat16_test.cc create mode 100644 tensorflow/core/framework/cancellation.cc create mode 100644 tensorflow/core/framework/cancellation.h create mode 100644 tensorflow/core/framework/cancellation_test.cc create mode 100644 tensorflow/core/framework/config.proto create mode 100644 tensorflow/core/framework/control_flow.h create mode 100644 
tensorflow/core/framework/device_attributes.proto create mode 100644 tensorflow/core/framework/device_base.cc create mode 100644 tensorflow/core/framework/device_base.h create mode 100644 tensorflow/core/framework/fake_input.cc create mode 100644 tensorflow/core/framework/fake_input.h create mode 100644 tensorflow/core/framework/function.cc create mode 100644 tensorflow/core/framework/function.h create mode 100644 tensorflow/core/framework/function.proto create mode 100644 tensorflow/core/framework/function_test.cc create mode 100644 tensorflow/core/framework/function_testlib.cc create mode 100644 tensorflow/core/framework/function_testlib.h create mode 100644 tensorflow/core/framework/graph.proto create mode 100644 tensorflow/core/framework/graph_def_util.cc create mode 100644 tensorflow/core/framework/graph_def_util.h create mode 100644 tensorflow/core/framework/kernel_def.proto create mode 100644 tensorflow/core/framework/kernel_def_builder.cc create mode 100644 tensorflow/core/framework/kernel_def_builder.h create mode 100644 tensorflow/core/framework/kernel_def_builder_test.cc create mode 100644 tensorflow/core/framework/lookup_interface.cc create mode 100644 tensorflow/core/framework/lookup_interface.h create mode 100644 tensorflow/core/framework/node_def_builder.cc create mode 100644 tensorflow/core/framework/node_def_builder.h create mode 100644 tensorflow/core/framework/node_def_builder_test.cc create mode 100644 tensorflow/core/framework/node_def_util.cc create mode 100644 tensorflow/core/framework/node_def_util.h create mode 100644 tensorflow/core/framework/node_def_util_test.cc create mode 100644 tensorflow/core/framework/numeric_op.h create mode 100644 tensorflow/core/framework/numeric_types.h create mode 100644 tensorflow/core/framework/op.cc create mode 100644 tensorflow/core/framework/op.h create mode 100644 tensorflow/core/framework/op_def.proto create mode 100644 tensorflow/core/framework/op_def_builder.cc create mode 100644 tensorflow/core/framework/op_def_builder.h create mode 100644 tensorflow/core/framework/op_def_builder_test.cc create mode 100644 tensorflow/core/framework/op_def_util.cc create mode 100644 tensorflow/core/framework/op_def_util.h create mode 100644 tensorflow/core/framework/op_def_util_test.cc create mode 100644 tensorflow/core/framework/op_gen_lib.cc create mode 100644 tensorflow/core/framework/op_gen_lib.h create mode 100644 tensorflow/core/framework/op_kernel.cc create mode 100644 tensorflow/core/framework/op_kernel.h create mode 100644 tensorflow/core/framework/op_kernel_test.cc create mode 100644 tensorflow/core/framework/op_segment.cc create mode 100644 tensorflow/core/framework/op_segment.h create mode 100644 tensorflow/core/framework/op_segment_test.cc create mode 100644 tensorflow/core/framework/queue_interface.h create mode 100644 tensorflow/core/framework/reader_interface.h create mode 100644 tensorflow/core/framework/reader_op_kernel.cc create mode 100644 tensorflow/core/framework/reader_op_kernel.h create mode 100644 tensorflow/core/framework/register_types.h create mode 100644 tensorflow/core/framework/rendezvous.cc create mode 100644 tensorflow/core/framework/rendezvous.h create mode 100644 tensorflow/core/framework/rendezvous_test.cc create mode 100644 tensorflow/core/framework/resource_mgr.cc create mode 100644 tensorflow/core/framework/resource_mgr.h create mode 100644 tensorflow/core/framework/resource_mgr_test.cc create mode 100644 tensorflow/core/framework/step_stats.proto create mode 100644 tensorflow/core/framework/summary.proto 
create mode 100644 tensorflow/core/framework/tensor.cc create mode 100644 tensorflow/core/framework/tensor.proto create mode 100644 tensorflow/core/framework/tensor_description.proto create mode 100644 tensorflow/core/framework/tensor_shape.cc create mode 100644 tensorflow/core/framework/tensor_shape.proto create mode 100644 tensorflow/core/framework/tensor_shape_test.cc create mode 100644 tensorflow/core/framework/tensor_slice.cc create mode 100644 tensorflow/core/framework/tensor_slice.h create mode 100644 tensorflow/core/framework/tensor_slice.proto create mode 100644 tensorflow/core/framework/tensor_slice_test.cc create mode 100644 tensorflow/core/framework/tensor_test.cc create mode 100644 tensorflow/core/framework/tensor_testutil.cc create mode 100644 tensorflow/core/framework/tensor_testutil.h create mode 100644 tensorflow/core/framework/tensor_types.h create mode 100644 tensorflow/core/framework/tensor_util.cc create mode 100644 tensorflow/core/framework/tensor_util.h create mode 100644 tensorflow/core/framework/tensor_util_test.cc create mode 100644 tensorflow/core/framework/tracking_allocator.cc create mode 100644 tensorflow/core/framework/tracking_allocator.h create mode 100644 tensorflow/core/framework/tracking_allocator_test.cc create mode 100644 tensorflow/core/framework/type_traits.h create mode 100644 tensorflow/core/framework/types.cc create mode 100644 tensorflow/core/framework/types.h create mode 100644 tensorflow/core/framework/types.proto create mode 100644 tensorflow/core/framework/types_test.cc create mode 100644 tensorflow/core/graph/algorithm.cc create mode 100644 tensorflow/core/graph/algorithm.h create mode 100644 tensorflow/core/graph/algorithm_test.cc create mode 100644 tensorflow/core/graph/colors.cc create mode 100644 tensorflow/core/graph/colors.h create mode 100644 tensorflow/core/graph/costmodel.cc create mode 100644 tensorflow/core/graph/costmodel.h create mode 100644 tensorflow/core/graph/costutil.cc create mode 100644 tensorflow/core/graph/costutil.h create mode 100644 tensorflow/core/graph/default_device.h create mode 100644 tensorflow/core/graph/dot.cc create mode 100644 tensorflow/core/graph/dot.h create mode 100644 tensorflow/core/graph/edgeset.cc create mode 100644 tensorflow/core/graph/edgeset.h create mode 100644 tensorflow/core/graph/edgeset_test.cc create mode 100644 tensorflow/core/graph/equal_graph_def.cc create mode 100644 tensorflow/core/graph/equal_graph_def.h create mode 100644 tensorflow/core/graph/equal_graph_def_test.cc create mode 100644 tensorflow/core/graph/graph.cc create mode 100644 tensorflow/core/graph/graph.h create mode 100644 tensorflow/core/graph/graph_constructor.cc create mode 100644 tensorflow/core/graph/graph_constructor.h create mode 100644 tensorflow/core/graph/graph_constructor_test.cc create mode 100644 tensorflow/core/graph/graph_def_builder.cc create mode 100644 tensorflow/core/graph/graph_def_builder.h create mode 100644 tensorflow/core/graph/graph_partition.cc create mode 100644 tensorflow/core/graph/graph_partition.h create mode 100644 tensorflow/core/graph/graph_partition_test.cc create mode 100644 tensorflow/core/graph/graph_test.cc create mode 100644 tensorflow/core/graph/node_builder.cc create mode 100644 tensorflow/core/graph/node_builder.h create mode 100644 tensorflow/core/graph/node_builder_test.cc create mode 100644 tensorflow/core/graph/optimizer_cse.cc create mode 100644 tensorflow/core/graph/optimizer_cse.h create mode 100644 tensorflow/core/graph/optimizer_cse_test.cc create mode 100644 
tensorflow/core/graph/subgraph.cc create mode 100644 tensorflow/core/graph/subgraph.h create mode 100644 tensorflow/core/graph/subgraph_test.cc create mode 100644 tensorflow/core/graph/tensor_id.cc create mode 100644 tensorflow/core/graph/tensor_id.h create mode 100644 tensorflow/core/graph/tensor_id_test.cc create mode 100644 tensorflow/core/graph/testlib.cc create mode 100644 tensorflow/core/graph/testlib.h create mode 100644 tensorflow/core/graph/types.h create mode 100644 tensorflow/core/kernels/adjust_contrast_op.cc create mode 100644 tensorflow/core/kernels/adjust_contrast_op.h create mode 100644 tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc create mode 100644 tensorflow/core/kernels/adjust_contrast_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/adjust_contrast_op_test.cc create mode 100644 tensorflow/core/kernels/aggregate_ops.cc create mode 100644 tensorflow/core/kernels/aggregate_ops.h create mode 100644 tensorflow/core/kernels/aggregate_ops_gpu.cu.cc create mode 100644 tensorflow/core/kernels/argmax_op.cc create mode 100644 tensorflow/core/kernels/argmax_op.h create mode 100644 tensorflow/core/kernels/argmax_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/assign_op.h create mode 100644 tensorflow/core/kernels/attention_ops.cc create mode 100644 tensorflow/core/kernels/avgpooling_op.cc create mode 100644 tensorflow/core/kernels/avgpooling_op.h create mode 100644 tensorflow/core/kernels/avgpooling_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/batch_matmul_op.cc create mode 100644 tensorflow/core/kernels/batch_norm_op.cc create mode 100644 tensorflow/core/kernels/batch_norm_op.h create mode 100644 tensorflow/core/kernels/batch_norm_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/bcast_ops.cc create mode 100644 tensorflow/core/kernels/bias_op.cc create mode 100644 tensorflow/core/kernels/bias_op.h create mode 100644 tensorflow/core/kernels/bias_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/candidate_sampler_ops.cc create mode 100644 tensorflow/core/kernels/cast_op.cc create mode 100644 tensorflow/core/kernels/cast_op.h create mode 100644 tensorflow/core/kernels/cast_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/cast_op_test.cc create mode 100644 tensorflow/core/kernels/check_numerics_op.cc create mode 100644 tensorflow/core/kernels/check_numerics_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/cholesky_op.cc create mode 100644 tensorflow/core/kernels/concat_op.cc create mode 100644 tensorflow/core/kernels/concat_op.h create mode 100644 tensorflow/core/kernels/concat_op_cpu.cc create mode 100644 tensorflow/core/kernels/concat_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/concat_op_test.cc create mode 100644 tensorflow/core/kernels/constant_op.cc create mode 100644 tensorflow/core/kernels/constant_op.h create mode 100644 tensorflow/core/kernels/constant_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/constant_op_test.cc create mode 100644 tensorflow/core/kernels/control_flow_ops.cc create mode 100644 tensorflow/core/kernels/control_flow_ops.h create mode 100644 tensorflow/core/kernels/control_flow_ops_test.cc create mode 100644 tensorflow/core/kernels/conv_2d.h create mode 100644 tensorflow/core/kernels/conv_grad_ops.cc create mode 100644 tensorflow/core/kernels/conv_ops.cc create mode 100644 tensorflow/core/kernels/conv_ops_gpu.cu.cc create mode 100644 tensorflow/core/kernels/conv_ops_gpu_2.cu.cc create mode 100644 tensorflow/core/kernels/conv_ops_gpu_3.cu.cc create mode 100644 
tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc create mode 100644 tensorflow/core/kernels/core_ops_test.cc create mode 100644 tensorflow/core/kernels/count_up_to_op.cc create mode 100644 tensorflow/core/kernels/cwise_op_abs.cc create mode 100644 tensorflow/core/kernels/cwise_op_add.cc create mode 100644 tensorflow/core/kernels/cwise_op_ceil.cc create mode 100644 tensorflow/core/kernels/cwise_op_complex.cc create mode 100644 tensorflow/core/kernels/cwise_op_conj.cc create mode 100644 tensorflow/core/kernels/cwise_op_cos.cc create mode 100644 tensorflow/core/kernels/cwise_op_div.cc create mode 100644 tensorflow/core/kernels/cwise_op_equal_to.cc create mode 100644 tensorflow/core/kernels/cwise_op_exp.cc create mode 100644 tensorflow/core/kernels/cwise_op_floor.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_add.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_div.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_greater_equal.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_isfinite.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_less.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_less_equal.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_log.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_logical_and.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_logical_not.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_logical_or.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_real.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_select.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_square.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc create mode 100644 tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc create mode 100644 
tensorflow/core/kernels/cwise_op_greater.cc create mode 100644 tensorflow/core/kernels/cwise_op_greater_equal.cc create mode 100644 tensorflow/core/kernels/cwise_op_imag.cc create mode 100644 tensorflow/core/kernels/cwise_op_inverse.cc create mode 100644 tensorflow/core/kernels/cwise_op_isfinite.cc create mode 100644 tensorflow/core/kernels/cwise_op_isinf.cc create mode 100644 tensorflow/core/kernels/cwise_op_isnan.cc create mode 100644 tensorflow/core/kernels/cwise_op_less.cc create mode 100644 tensorflow/core/kernels/cwise_op_less_equal.cc create mode 100644 tensorflow/core/kernels/cwise_op_log.cc create mode 100644 tensorflow/core/kernels/cwise_op_logical_and.cc create mode 100644 tensorflow/core/kernels/cwise_op_logical_not.cc create mode 100644 tensorflow/core/kernels/cwise_op_logical_or.cc create mode 100644 tensorflow/core/kernels/cwise_op_maximum.cc create mode 100644 tensorflow/core/kernels/cwise_op_minimum.cc create mode 100644 tensorflow/core/kernels/cwise_op_mod.cc create mode 100644 tensorflow/core/kernels/cwise_op_mul.cc create mode 100644 tensorflow/core/kernels/cwise_op_neg.cc create mode 100644 tensorflow/core/kernels/cwise_op_not_equal_to.cc create mode 100644 tensorflow/core/kernels/cwise_op_pow.cc create mode 100644 tensorflow/core/kernels/cwise_op_real.cc create mode 100644 tensorflow/core/kernels/cwise_op_rsqrt.cc create mode 100644 tensorflow/core/kernels/cwise_op_select.cc create mode 100644 tensorflow/core/kernels/cwise_op_sigmoid.cc create mode 100644 tensorflow/core/kernels/cwise_op_sign.cc create mode 100644 tensorflow/core/kernels/cwise_op_sin.cc create mode 100644 tensorflow/core/kernels/cwise_op_sqrt.cc create mode 100644 tensorflow/core/kernels/cwise_op_square.cc create mode 100644 tensorflow/core/kernels/cwise_op_sub.cc create mode 100644 tensorflow/core/kernels/cwise_op_tanh.cc create mode 100644 tensorflow/core/kernels/cwise_ops.h create mode 100644 tensorflow/core/kernels/cwise_ops_common.cc create mode 100644 tensorflow/core/kernels/cwise_ops_common.h create mode 100644 tensorflow/core/kernels/cwise_ops_gpu_common.cu.h create mode 100644 tensorflow/core/kernels/cwise_ops_test.cc create mode 100644 tensorflow/core/kernels/decode_csv_op.cc create mode 100644 tensorflow/core/kernels/decode_jpeg_op.cc create mode 100644 tensorflow/core/kernels/decode_png_op.cc create mode 100644 tensorflow/core/kernels/decode_raw_op.cc create mode 100644 tensorflow/core/kernels/dense_update_ops.cc create mode 100644 tensorflow/core/kernels/dense_update_ops.h create mode 100644 tensorflow/core/kernels/dense_update_ops_gpu.cu.cc create mode 100644 tensorflow/core/kernels/determinant_op.cc create mode 100644 tensorflow/core/kernels/diag_op.cc create mode 100644 tensorflow/core/kernels/dynamic_partition_op.cc create mode 100644 tensorflow/core/kernels/dynamic_partition_op_test.cc create mode 100644 tensorflow/core/kernels/dynamic_stitch_op.cc create mode 100644 tensorflow/core/kernels/dynamic_stitch_op_test.cc create mode 100644 tensorflow/core/kernels/edit_distance_op.cc create mode 100644 tensorflow/core/kernels/encode_jpeg_op.cc create mode 100644 tensorflow/core/kernels/encode_png_op.cc create mode 100644 tensorflow/core/kernels/example_parsing_ops.cc create mode 100644 tensorflow/core/kernels/fact_op.cc create mode 100644 tensorflow/core/kernels/fifo_queue.cc create mode 100644 tensorflow/core/kernels/fifo_queue.h create mode 100644 tensorflow/core/kernels/fifo_queue_op.cc create mode 100644 tensorflow/core/kernels/fill_functor.h create mode 100644 
tensorflow/core/kernels/fixed_length_record_reader_op.cc create mode 100644 tensorflow/core/kernels/gather_op.cc create mode 100644 tensorflow/core/kernels/gather_op_test.cc create mode 100644 tensorflow/core/kernels/identity_op.cc create mode 100644 tensorflow/core/kernels/identity_op.h create mode 100644 tensorflow/core/kernels/identity_op_test.cc create mode 100644 tensorflow/core/kernels/identity_reader_op.cc create mode 100644 tensorflow/core/kernels/in_topk_op.cc create mode 100644 tensorflow/core/kernels/initializable_lookup_table.cc create mode 100644 tensorflow/core/kernels/initializable_lookup_table.h create mode 100644 tensorflow/core/kernels/io.cc create mode 100644 tensorflow/core/kernels/io.h create mode 100644 tensorflow/core/kernels/l2loss_op.cc create mode 100644 tensorflow/core/kernels/l2loss_op.h create mode 100644 tensorflow/core/kernels/l2loss_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/linalg_ops_common.cc create mode 100644 tensorflow/core/kernels/linalg_ops_common.h create mode 100644 tensorflow/core/kernels/listdiff_op.cc create mode 100644 tensorflow/core/kernels/logging_ops.cc create mode 100644 tensorflow/core/kernels/logging_ops_test.cc create mode 100644 tensorflow/core/kernels/lookup_table_init_op.cc create mode 100644 tensorflow/core/kernels/lookup_table_op.cc create mode 100644 tensorflow/core/kernels/lookup_table_op.h create mode 100644 tensorflow/core/kernels/lookup_util.cc create mode 100644 tensorflow/core/kernels/lookup_util.h create mode 100644 tensorflow/core/kernels/lrn_op.cc create mode 100644 tensorflow/core/kernels/lrn_op_test.cc create mode 100644 tensorflow/core/kernels/matching_files_op.cc create mode 100644 tensorflow/core/kernels/matmul_op.cc create mode 100644 tensorflow/core/kernels/matmul_op.h create mode 100644 tensorflow/core/kernels/matmul_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/matmul_op_test.cc create mode 100644 tensorflow/core/kernels/matrix_inverse_op.cc create mode 100644 tensorflow/core/kernels/maxpooling_op.cc create mode 100644 tensorflow/core/kernels/maxpooling_op.h create mode 100644 tensorflow/core/kernels/maxpooling_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/maxpooling_op_gpu.h create mode 100644 tensorflow/core/kernels/no_op.cc create mode 100644 tensorflow/core/kernels/no_op.h create mode 100644 tensorflow/core/kernels/ops_testutil.cc create mode 100644 tensorflow/core/kernels/ops_testutil.h create mode 100644 tensorflow/core/kernels/ops_util.cc create mode 100644 tensorflow/core/kernels/ops_util.h create mode 100644 tensorflow/core/kernels/ops_util_test.cc create mode 100644 tensorflow/core/kernels/pack_op.cc create mode 100644 tensorflow/core/kernels/pad_op.cc create mode 100644 tensorflow/core/kernels/pad_op.h create mode 100644 tensorflow/core/kernels/pad_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/pooling_ops_common.cc create mode 100644 tensorflow/core/kernels/pooling_ops_common.h create mode 100644 tensorflow/core/kernels/pooling_ops_common_gpu.h create mode 100644 tensorflow/core/kernels/queue_base.cc create mode 100644 tensorflow/core/kernels/queue_base.h create mode 100644 tensorflow/core/kernels/queue_ops.cc create mode 100644 tensorflow/core/kernels/random_crop_op.cc create mode 100644 tensorflow/core/kernels/random_crop_op_test.cc create mode 100644 tensorflow/core/kernels/random_op.cc create mode 100644 tensorflow/core/kernels/random_op.h create mode 100644 tensorflow/core/kernels/random_op_gpu.cu.cc create mode 100644 
tensorflow/core/kernels/random_op_test.cc create mode 100644 tensorflow/core/kernels/random_shuffle_op.cc create mode 100644 tensorflow/core/kernels/random_shuffle_queue_op.cc create mode 100644 tensorflow/core/kernels/range_sampler.cc create mode 100644 tensorflow/core/kernels/range_sampler.h create mode 100644 tensorflow/core/kernels/range_sampler_test.cc create mode 100644 tensorflow/core/kernels/reader_base.cc create mode 100644 tensorflow/core/kernels/reader_base.h create mode 100644 tensorflow/core/kernels/reader_base.proto create mode 100644 tensorflow/core/kernels/reader_ops.cc create mode 100644 tensorflow/core/kernels/reduction_ops.h create mode 100644 tensorflow/core/kernels/reduction_ops_all.cc create mode 100644 tensorflow/core/kernels/reduction_ops_any.cc create mode 100644 tensorflow/core/kernels/reduction_ops_common.h create mode 100644 tensorflow/core/kernels/reduction_ops_gpu.cu.cc create mode 100644 tensorflow/core/kernels/reduction_ops_max.cc create mode 100644 tensorflow/core/kernels/reduction_ops_mean.cc create mode 100644 tensorflow/core/kernels/reduction_ops_min.cc create mode 100644 tensorflow/core/kernels/reduction_ops_prod.cc create mode 100644 tensorflow/core/kernels/reduction_ops_sum.cc create mode 100644 tensorflow/core/kernels/reduction_ops_test.cc create mode 100644 tensorflow/core/kernels/reference_gemm.h create mode 100644 tensorflow/core/kernels/relu_op.cc create mode 100644 tensorflow/core/kernels/relu_op.h create mode 100644 tensorflow/core/kernels/relu_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/reshape_op.cc create mode 100644 tensorflow/core/kernels/reshape_op.h create mode 100644 tensorflow/core/kernels/resize_area_op.cc create mode 100644 tensorflow/core/kernels/resize_bicubic_op.cc create mode 100644 tensorflow/core/kernels/resize_bilinear_op.cc create mode 100644 tensorflow/core/kernels/resize_bilinear_op_test.cc create mode 100644 tensorflow/core/kernels/resize_nearest_neighbor_op.cc create mode 100644 tensorflow/core/kernels/resize_nearest_neighbor_op_test.cc create mode 100644 tensorflow/core/kernels/restore_op.cc create mode 100644 tensorflow/core/kernels/restore_op_test.cc create mode 100644 tensorflow/core/kernels/reverse_op.cc create mode 100644 tensorflow/core/kernels/reverse_op.h create mode 100644 tensorflow/core/kernels/reverse_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/reverse_op_test.cc create mode 100644 tensorflow/core/kernels/reverse_sequence_op.cc create mode 100644 tensorflow/core/kernels/reverse_sequence_op.h create mode 100644 tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/save_op.cc create mode 100644 tensorflow/core/kernels/save_op_test.cc create mode 100644 tensorflow/core/kernels/scatter_op.cc create mode 100644 tensorflow/core/kernels/scatter_op_test.cc create mode 100644 tensorflow/core/kernels/segment_reduction_ops.cc create mode 100644 tensorflow/core/kernels/segment_reduction_ops_test.cc create mode 100644 tensorflow/core/kernels/sendrecv_ops.cc create mode 100644 tensorflow/core/kernels/sendrecv_ops.h create mode 100644 tensorflow/core/kernels/sequence_ops.cc create mode 100644 tensorflow/core/kernels/shape_ops.cc create mode 100644 tensorflow/core/kernels/slice_op.cc create mode 100644 tensorflow/core/kernels/slice_op.h create mode 100644 tensorflow/core/kernels/slice_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/slice_op_test.cc create mode 100644 tensorflow/core/kernels/softmax_op.cc create mode 100644 
tensorflow/core/kernels/softmax_op.h create mode 100644 tensorflow/core/kernels/softmax_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/softplus_op.cc create mode 100644 tensorflow/core/kernels/softplus_op.h create mode 100644 tensorflow/core/kernels/softplus_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/sparse_concat_op.cc create mode 100644 tensorflow/core/kernels/sparse_matmul_op.cc create mode 100644 tensorflow/core/kernels/sparse_matmul_op_test.cc create mode 100644 tensorflow/core/kernels/sparse_reorder_op.cc create mode 100644 tensorflow/core/kernels/sparse_to_dense_op.cc create mode 100644 tensorflow/core/kernels/sparse_to_dense_op_test.cc create mode 100644 tensorflow/core/kernels/split_op.cc create mode 100644 tensorflow/core/kernels/split_op.h create mode 100644 tensorflow/core/kernels/split_op_cpu.cc create mode 100644 tensorflow/core/kernels/split_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/string_to_hash_bucket_op.cc create mode 100644 tensorflow/core/kernels/string_to_number_op.cc create mode 100644 tensorflow/core/kernels/summary_image_op.cc create mode 100644 tensorflow/core/kernels/summary_image_op_test.cc create mode 100644 tensorflow/core/kernels/summary_op.cc create mode 100644 tensorflow/core/kernels/summary_op_test.cc create mode 100644 tensorflow/core/kernels/text_line_reader_op.cc create mode 100644 tensorflow/core/kernels/tf_record_reader_op.cc create mode 100644 tensorflow/core/kernels/tile_ops.cc create mode 100644 tensorflow/core/kernels/tile_ops.h create mode 100644 tensorflow/core/kernels/tile_ops_gpu.cu.cc create mode 100644 tensorflow/core/kernels/topk_op.cc create mode 100644 tensorflow/core/kernels/training_ops.cc create mode 100644 tensorflow/core/kernels/training_ops.h create mode 100644 tensorflow/core/kernels/training_ops_gpu.cu.cc create mode 100644 tensorflow/core/kernels/training_ops_test.cc create mode 100644 tensorflow/core/kernels/transpose_op.cc create mode 100644 tensorflow/core/kernels/transpose_op.h create mode 100644 tensorflow/core/kernels/transpose_op_functor.h create mode 100644 tensorflow/core/kernels/transpose_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/unique_op.cc create mode 100644 tensorflow/core/kernels/unique_op_test.cc create mode 100644 tensorflow/core/kernels/unpack_op.cc create mode 100644 tensorflow/core/kernels/variable_ops.cc create mode 100644 tensorflow/core/kernels/variable_ops.h create mode 100644 tensorflow/core/kernels/where_op.cc create mode 100644 tensorflow/core/kernels/where_op.h create mode 100644 tensorflow/core/kernels/whole_file_read_ops.cc create mode 100644 tensorflow/core/kernels/xent_op.cc create mode 100644 tensorflow/core/kernels/xent_op.h create mode 100644 tensorflow/core/kernels/xent_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/xent_op_test.cc create mode 100644 tensorflow/core/lib/core/arena.cc create mode 100644 tensorflow/core/lib/core/arena.h create mode 100644 tensorflow/core/lib/core/arena_test.cc create mode 100644 tensorflow/core/lib/core/bit_cast_test.cc create mode 100644 tensorflow/core/lib/core/bits.h create mode 100644 tensorflow/core/lib/core/blocking_counter.h create mode 100644 tensorflow/core/lib/core/blocking_counter_test.cc create mode 100644 tensorflow/core/lib/core/casts.h create mode 100644 tensorflow/core/lib/core/coding.cc create mode 100644 tensorflow/core/lib/core/coding.h create mode 100644 tensorflow/core/lib/core/coding_test.cc create mode 100644 tensorflow/core/lib/core/command_line_flags.cc create mode 100644 
tensorflow/core/lib/core/command_line_flags.h create mode 100644 tensorflow/core/lib/core/error_codes.proto create mode 100644 tensorflow/core/lib/core/errors.h create mode 100644 tensorflow/core/lib/core/notification.h create mode 100644 tensorflow/core/lib/core/notification_test.cc create mode 100644 tensorflow/core/lib/core/raw_coding.h create mode 100644 tensorflow/core/lib/core/refcount.cc create mode 100644 tensorflow/core/lib/core/refcount.h create mode 100644 tensorflow/core/lib/core/refcount_test.cc create mode 100644 tensorflow/core/lib/core/status.cc create mode 100644 tensorflow/core/lib/core/status_test.cc create mode 100644 tensorflow/core/lib/core/status_test_util.h create mode 100644 tensorflow/core/lib/core/stringpiece.cc create mode 100644 tensorflow/core/lib/core/stringpiece.h create mode 100644 tensorflow/core/lib/core/threadpool.cc create mode 100644 tensorflow/core/lib/core/threadpool.h create mode 100644 tensorflow/core/lib/core/threadpool_test.cc create mode 100644 tensorflow/core/lib/gtl/array_slice.h create mode 100644 tensorflow/core/lib/gtl/array_slice_internal.h create mode 100644 tensorflow/core/lib/gtl/array_slice_test.cc create mode 100644 tensorflow/core/lib/gtl/edit_distance.h create mode 100644 tensorflow/core/lib/gtl/edit_distance_test.cc create mode 100644 tensorflow/core/lib/gtl/inlined_vector.h create mode 100644 tensorflow/core/lib/gtl/inlined_vector_test.cc create mode 100644 tensorflow/core/lib/gtl/int_type.h create mode 100644 tensorflow/core/lib/gtl/int_type_test.cc create mode 100644 tensorflow/core/lib/gtl/iterator_range.h create mode 100644 tensorflow/core/lib/gtl/iterator_range_test.cc create mode 100644 tensorflow/core/lib/gtl/manual_constructor.h create mode 100644 tensorflow/core/lib/gtl/manual_constructor_test.cc create mode 100644 tensorflow/core/lib/gtl/map_util.h create mode 100644 tensorflow/core/lib/gtl/map_util_test.cc create mode 100644 tensorflow/core/lib/gtl/stl_util.h create mode 100644 tensorflow/core/lib/gtl/top_n.h create mode 100644 tensorflow/core/lib/gtl/top_n_test.cc create mode 100644 tensorflow/core/lib/hash/crc32c.cc create mode 100644 tensorflow/core/lib/hash/crc32c.h create mode 100644 tensorflow/core/lib/hash/crc32c_test.cc create mode 100644 tensorflow/core/lib/hash/hash.cc create mode 100644 tensorflow/core/lib/hash/hash.h create mode 100644 tensorflow/core/lib/hash/hash_test.cc create mode 100644 tensorflow/core/lib/histogram/histogram.cc create mode 100644 tensorflow/core/lib/histogram/histogram.h create mode 100644 tensorflow/core/lib/histogram/histogram_test.cc create mode 100644 tensorflow/core/lib/io/block.cc create mode 100644 tensorflow/core/lib/io/block.h create mode 100644 tensorflow/core/lib/io/block_builder.cc create mode 100644 tensorflow/core/lib/io/block_builder.h create mode 100644 tensorflow/core/lib/io/format.cc create mode 100644 tensorflow/core/lib/io/format.h create mode 100644 tensorflow/core/lib/io/inputbuffer.cc create mode 100644 tensorflow/core/lib/io/inputbuffer.h create mode 100644 tensorflow/core/lib/io/inputbuffer_test.cc create mode 100644 tensorflow/core/lib/io/iterator.cc create mode 100644 tensorflow/core/lib/io/iterator.h create mode 100644 tensorflow/core/lib/io/match.cc create mode 100644 tensorflow/core/lib/io/match.h create mode 100644 tensorflow/core/lib/io/match_test.cc create mode 100644 tensorflow/core/lib/io/path.cc create mode 100644 tensorflow/core/lib/io/path.h create mode 100644 tensorflow/core/lib/io/path_test.cc create mode 100644 
tensorflow/core/lib/io/record_reader.cc create mode 100644 tensorflow/core/lib/io/record_reader.h create mode 100644 tensorflow/core/lib/io/record_writer.cc create mode 100644 tensorflow/core/lib/io/record_writer.h create mode 100644 tensorflow/core/lib/io/recordio_test.cc create mode 100644 tensorflow/core/lib/io/table.cc create mode 100644 tensorflow/core/lib/io/table.h create mode 100644 tensorflow/core/lib/io/table_builder.cc create mode 100644 tensorflow/core/lib/io/table_builder.h create mode 100644 tensorflow/core/lib/io/table_format.txt create mode 100644 tensorflow/core/lib/io/table_options.h create mode 100644 tensorflow/core/lib/io/table_test.cc create mode 100644 tensorflow/core/lib/io/two_level_iterator.cc create mode 100644 tensorflow/core/lib/io/two_level_iterator.h create mode 100644 tensorflow/core/lib/jpeg/jpeg_handle.cc create mode 100644 tensorflow/core/lib/jpeg/jpeg_handle.h create mode 100644 tensorflow/core/lib/jpeg/jpeg_mem.cc create mode 100644 tensorflow/core/lib/jpeg/jpeg_mem.h create mode 100644 tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc create mode 100644 tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg create mode 100644 tensorflow/core/lib/jpeg/testdata/corrupt.jpg create mode 100644 tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg create mode 100644 tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg create mode 100644 tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg create mode 100644 tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg create mode 100644 tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg create mode 100644 tensorflow/core/lib/png/png_io.cc create mode 100644 tensorflow/core/lib/png/png_io.h create mode 100644 tensorflow/core/lib/png/testdata/lena_gray.png create mode 100644 tensorflow/core/lib/png/testdata/lena_rgba.png create mode 100644 tensorflow/core/lib/random/distribution_sampler.cc create mode 100644 tensorflow/core/lib/random/distribution_sampler.h create mode 100644 tensorflow/core/lib/random/distribution_sampler_test.cc create mode 100644 tensorflow/core/lib/random/exact_uniform_int.h create mode 100644 tensorflow/core/lib/random/philox_random.h create mode 100644 tensorflow/core/lib/random/philox_random_test.cc create mode 100644 tensorflow/core/lib/random/philox_random_test_utils.h create mode 100644 tensorflow/core/lib/random/random.cc create mode 100644 tensorflow/core/lib/random/random.h create mode 100644 tensorflow/core/lib/random/random_distributions.h create mode 100644 tensorflow/core/lib/random/random_distributions_test.cc create mode 100644 tensorflow/core/lib/random/random_test.cc create mode 100644 tensorflow/core/lib/random/simple_philox.cc create mode 100644 tensorflow/core/lib/random/simple_philox.h create mode 100644 tensorflow/core/lib/random/simple_philox_test.cc create mode 100644 tensorflow/core/lib/random/weighted_picker.cc create mode 100644 tensorflow/core/lib/random/weighted_picker.h create mode 100644 tensorflow/core/lib/random/weighted_picker_test.cc create mode 100644 tensorflow/core/lib/strings/numbers.cc create mode 100644 tensorflow/core/lib/strings/numbers.h create mode 100644 tensorflow/core/lib/strings/numbers_test.cc create mode 100644 tensorflow/core/lib/strings/ordered_code.cc create mode 100644 tensorflow/core/lib/strings/ordered_code.h create mode 100644 tensorflow/core/lib/strings/ordered_code_test.cc create mode 100644 tensorflow/core/lib/strings/str_util.cc create mode 100644 tensorflow/core/lib/strings/str_util.h create mode 100644 
tensorflow/core/lib/strings/str_util_test.cc create mode 100644 tensorflow/core/lib/strings/strcat.cc create mode 100644 tensorflow/core/lib/strings/strcat.h create mode 100644 tensorflow/core/lib/strings/strcat_test.cc create mode 100644 tensorflow/core/lib/strings/stringprintf.cc create mode 100644 tensorflow/core/lib/strings/stringprintf.h create mode 100644 tensorflow/core/lib/strings/stringprintf_test.cc create mode 100644 tensorflow/core/ops/array_ops.cc create mode 100644 tensorflow/core/ops/attention_ops.cc create mode 100644 tensorflow/core/ops/candidate_sampling_ops.cc create mode 100644 tensorflow/core/ops/control_flow_ops.cc create mode 100644 tensorflow/core/ops/data_flow_ops.cc create mode 100644 tensorflow/core/ops/image_ops.cc create mode 100644 tensorflow/core/ops/io_ops.cc create mode 100644 tensorflow/core/ops/linalg_ops.cc create mode 100644 tensorflow/core/ops/logging_ops.cc create mode 100644 tensorflow/core/ops/math_ops.cc create mode 100644 tensorflow/core/ops/nn_ops.cc create mode 100644 tensorflow/core/ops/no_op.cc create mode 100644 tensorflow/core/ops/parsing_ops.cc create mode 100644 tensorflow/core/ops/random_ops.cc create mode 100644 tensorflow/core/ops/sendrecv_ops.cc create mode 100644 tensorflow/core/ops/sparse_ops.cc create mode 100644 tensorflow/core/ops/state_ops.cc create mode 100644 tensorflow/core/ops/string_ops.cc create mode 100644 tensorflow/core/ops/summary_ops.cc create mode 100644 tensorflow/core/ops/training_ops.cc create mode 100644 tensorflow/core/platform/default/build_config.bzl create mode 100644 tensorflow/core/platform/default/build_config/BUILD create mode 100644 tensorflow/core/platform/default/build_config_root.bzl create mode 100644 tensorflow/core/platform/default/dynamic_annotations.h create mode 100644 tensorflow/core/platform/default/integral_types.h create mode 100644 tensorflow/core/platform/default/logging.cc create mode 100644 tensorflow/core/platform/default/logging.h create mode 100644 tensorflow/core/platform/default/mutex.h create mode 100644 tensorflow/core/platform/default/protobuf.h create mode 100644 tensorflow/core/platform/default/stream_executor_util.h create mode 100644 tensorflow/core/platform/default/test_benchmark.cc create mode 100644 tensorflow/core/platform/default/thread_annotations.h create mode 100644 tensorflow/core/platform/default/tracing.cc create mode 100644 tensorflow/core/platform/default/tracing_impl.h create mode 100644 tensorflow/core/platform/env.cc create mode 100644 tensorflow/core/platform/env_test.cc create mode 100644 tensorflow/core/platform/init_main.h create mode 100644 tensorflow/core/platform/integral_types_test.cc create mode 100644 tensorflow/core/platform/logging.h create mode 100644 tensorflow/core/platform/logging_test.cc create mode 100644 tensorflow/core/platform/port.h create mode 100644 tensorflow/core/platform/port_test.cc create mode 100644 tensorflow/core/platform/posix/env.cc create mode 100644 tensorflow/core/platform/posix/port.cc create mode 100644 tensorflow/core/platform/protobuf.h create mode 100644 tensorflow/core/platform/protobuf_util.cc create mode 100644 tensorflow/core/platform/regexp.h create mode 100644 tensorflow/core/platform/stream_executor_util.h create mode 100644 tensorflow/core/platform/tensor_coding.cc create mode 100644 tensorflow/core/platform/tensor_coding.h create mode 100644 tensorflow/core/platform/test.cc create mode 100644 tensorflow/core/platform/test.h create mode 100644 tensorflow/core/platform/test_benchmark.h create mode 100644 
tensorflow/core/platform/test_main.cc create mode 100644 tensorflow/core/platform/thread_annotations.h create mode 100644 tensorflow/core/platform/tracing.cc create mode 100644 tensorflow/core/platform/tracing.h create mode 100644 tensorflow/core/public/README.md create mode 100644 tensorflow/core/public/env.h create mode 100644 tensorflow/core/public/session.h create mode 100644 tensorflow/core/public/session_options.h create mode 100644 tensorflow/core/public/status.h create mode 100644 tensorflow/core/public/tensor.h create mode 100644 tensorflow/core/public/tensor_c_api.h create mode 100644 tensorflow/core/public/tensor_shape.h create mode 100644 tensorflow/core/public/tensorflow_server.h create mode 100644 tensorflow/core/user_ops/fact.cc create mode 100644 tensorflow/core/util/bcast.cc create mode 100644 tensorflow/core/util/bcast.h create mode 100644 tensorflow/core/util/bcast_test.cc create mode 100644 tensorflow/core/util/device_name_utils.cc create mode 100644 tensorflow/core/util/device_name_utils.h create mode 100644 tensorflow/core/util/device_name_utils_test.cc create mode 100644 tensorflow/core/util/event.proto create mode 100644 tensorflow/core/util/events_writer.cc create mode 100644 tensorflow/core/util/events_writer.h create mode 100644 tensorflow/core/util/events_writer_test.cc create mode 100644 tensorflow/core/util/guarded_philox_random.cc create mode 100644 tensorflow/core/util/guarded_philox_random.h create mode 100644 tensorflow/core/util/padding.cc create mode 100644 tensorflow/core/util/padding.h create mode 100644 tensorflow/core/util/port.cc create mode 100644 tensorflow/core/util/port.h create mode 100644 tensorflow/core/util/saved_tensor_slice.proto create mode 100644 tensorflow/core/util/saved_tensor_slice_util.cc create mode 100644 tensorflow/core/util/saved_tensor_slice_util.h create mode 100644 tensorflow/core/util/saved_tensor_slice_util_test.cc create mode 100644 tensorflow/core/util/sparse/README.md create mode 100644 tensorflow/core/util/sparse/dim_comparator.h create mode 100644 tensorflow/core/util/sparse/group_iterator.cc create mode 100644 tensorflow/core/util/sparse/group_iterator.h create mode 100644 tensorflow/core/util/sparse/sparse_tensor.h create mode 100644 tensorflow/core/util/sparse/sparse_tensor_test.cc create mode 100644 tensorflow/core/util/tensor_slice_reader.cc create mode 100644 tensorflow/core/util/tensor_slice_reader.h create mode 100644 tensorflow/core/util/tensor_slice_reader_cache.cc create mode 100644 tensorflow/core/util/tensor_slice_reader_cache.h create mode 100644 tensorflow/core/util/tensor_slice_reader_test.cc create mode 100644 tensorflow/core/util/tensor_slice_set.cc create mode 100644 tensorflow/core/util/tensor_slice_set.h create mode 100644 tensorflow/core/util/tensor_slice_set_test.cc create mode 100644 tensorflow/core/util/tensor_slice_util.h create mode 100644 tensorflow/core/util/tensor_slice_util_test.cc create mode 100644 tensorflow/core/util/tensor_slice_writer.cc create mode 100644 tensorflow/core/util/tensor_slice_writer.h create mode 100644 tensorflow/core/util/tensor_slice_writer_test.cc create mode 100644 tensorflow/core/util/use_cudnn.cc create mode 100644 tensorflow/core/util/use_cudnn.h create mode 100644 tensorflow/core/util/util.cc create mode 100644 tensorflow/core/util/util.h create mode 100644 tensorflow/core/util/work_sharder.cc create mode 100644 tensorflow/core/util/work_sharder.h create mode 100644 tensorflow/core/util/work_sharder_test.cc (limited to 'tensorflow/core') diff --git 
a/tensorflow/core/BUILD b/tensorflow/core/BUILD new file mode 100644 index 0000000000..c2fcfeed8c --- /dev/null +++ b/tensorflow/core/BUILD @@ -0,0 +1,695 @@ +# Description: +# TensorFlow is a computational framework, primarily for use in machine +# learning applications. + +package(default_visibility = ["//tensorflow:internal"]) + +package_group(name = "friends") + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("/tensorflow/tensorflow", "tf_copts") +load("/tensorflow/tensorflow", "tf_cc_tests") +load("/tensorflow/tensorflow", "tf_cuda_library") +load("/tensorflow/tensorflow", "tf_gen_op_libs") +load("/tensorflow/tensorflow", "tf_gpu_kernel_library") + +# For platform specific build config +load( + "/tensorflow/core/platform/default/build_config", + "tf_proto_library", + "tf_additional_lib_srcs", + "tf_additional_test_srcs", + "tf_kernel_tests_linkstatic", +) +load( + "/tensorflow/core/platform/default/build_config_root", + "tf_cuda_tests_tags", +) + +cc_library( + name = "lib", + srcs = glob( + [ + "lib/**/*.h", + "lib/**/*.cc", + "platform/*.h", + "platform/*.cc", + "public/*.h", + ] + tf_additional_lib_srcs(), + exclude = [ + "**/*test*", + ], + ), + copts = tf_copts(), + visibility = [ + ":friends", + "//tensorflow:internal", + ], + deps = [ + ":protos_cc", + "//tensorflow/core/platform/default/build_config:platformlib", + ], +) + +tf_cuda_library( + name = "core_cpu", + srcs = glob( + [ + "common_runtime/**/*.h", + "client/**/*.cc", + "common_runtime/**/*.cc", + "graph/**/*.h", + "graph/**/*.cc", + ], + exclude = [ + "**/*test*", + "**/*main.cc", + "common_runtime/gpu/*.cc", + "common_runtime/copy_tensor.cc", + "common_runtime/gpu_device_factory.cc", + "common_runtime/local_session.cc", + "common_runtime/local_session.h", + ], + ), + hdrs = glob(["public/**/*.h"]), + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":copy_tensor", + ":framework", + ":lib", + ":protos_cc", + "//third_party/eigen3", + ], + alwayslink = 1, +) + +tf_cuda_library( + name = "framework", + srcs = glob( + [ + "framework/**/*.h", + "framework/**/*.cc", + "util/**/*.h", + "util/**/*.cc", + ], + exclude = [ + "**/*test*", + "**/*main.cc", + ], + ), + hdrs = glob(["public/**/*.h"]), + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":protos_cc", + "//third_party/eigen3", + ], + alwayslink = 1, +) + +tf_cuda_library( + name = "local", + srcs = [ + "common_runtime/local_session.cc", + "common_runtime/local_session.h", + ], + copts = tf_copts(), + cuda_deps = [ + ":cuda", + ], + linkstatic = 1, + deps = [ + ":core", + ":lib", + ], + alwayslink = 1, +) + +cc_library( + name = "copy_tensor", + deps = [ + ":lib", + ":protos_cc", + ":stream_executor", + "//third_party/eigen3", + ], +) + +tf_cuda_library( + name = "gpu_runtime", + srcs = glob( + [ + "common_runtime/gpu/**/*.h", + "common_runtime/gpu/**/*.cc", + ], + exclude = [ + "**/*main.cc", + "**/*test.cc", + ], + ), + copts = tf_copts(), + cuda_deps = [ + ":cuda", + ], + linkstatic = 1, + deps = [ + ":core_cpu", + ":lib", + ":protos_cc", + ":stream_executor", + "//third_party/eigen3", + ], + alwayslink = 1, +) + +# Test support library needed for higher-level tests +cc_library( + name = "testlib", + testonly = 1, + srcs = [ + "common_runtime/kernel_benchmark_testlib.cc", + "common_runtime/kernel_benchmark_testlib.h", + "framework/function_testlib.cc", + "framework/function_testlib.h", + "framework/tensor_testutil.cc", + "framework/tensor_testutil.h", + "graph/testlib.cc", + 
"graph/testlib.h", + ], + copts = tf_copts(), + visibility = [ + ":friends", + "//tensorflow:internal", + ], + deps = [ + ":core_cpu", + ":tensorflow", + ":test", + "//tensorflow/core/platform/default/build_config:gtest", + ], +) + +tf_cuda_library( + name = "tensorflow_opensource", + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":core", + ":gpu_runtime", + ":kernels", + ":lib", + ":local", + ], +) + +tf_cuda_library( + name = "kernels", + srcs = glob( + [ + "kernels/**/*.h", + "kernels/**/*.cc", + "ops/**/*.h", + "ops/**/*.cc", + "user_ops/**/*.h", + "user_ops/**/*.cc", + ], + exclude = [ + "**/*test*", + "**/*main.cc", + "kernels/**/*.cu.cc", + "user_ops/**/*.cu.cc", + ], + ), + copts = tf_copts(), + cuda_deps = [ + ":gpu_kernels", + ":cuda", + ], + linkstatic = 0, + visibility = ["//visibility:public"], + deps = [ + "@gemmlowp//:eight_bit_int_gemm", + ":core", + ":lib", + ":protos_cc", + ":stream_executor", + "//tensorflow/models/embedding:word2vec_kernels", + "//tensorflow/models/embedding:word2vec_ops", + "//third_party/eigen3", + ], + alwayslink = 1, +) + +tf_gpu_kernel_library( + name = "gpu_kernels", + srcs = glob( + [ + "kernels/**/*.h", + "kernels/*.cu.cc", + "user_ops/**/*.h", + "user_ops/*.cu.cc", + ], + ), + visibility = ["//visibility:public"], + deps = [ + "//third_party/eigen3", + ], +) + +# Test support library needed for all tests +cc_library( + name = "test", + testonly = 1, + srcs = [ + "platform/test.cc", + ] + tf_additional_test_srcs(), + hdrs = [ + "platform/test.h", + "platform/test_benchmark.h", + ], + copts = tf_copts(), + linkopts = ["-lm"], + deps = [ + ":lib", + "//tensorflow/core/platform/default/build_config:gtest", + ], +) + +# Main program for tests +cc_library( + name = "test_main", + testonly = 1, + srcs = ["platform/test_main.cc"], + copts = tf_copts(), + linkopts = ["-lm"], + deps = [ + ":test", + "//tensorflow/core/platform/default/build_config:test_main", + ], +) + +# TODO(opensource): Make it work externally +tf_proto_library( + name = "protos_all", + srcs = glob(["**/*.proto"]), + cc_api_version = 2, + go_api_version = 2, + java_api_version = 2, + py_api_version = 2, + visibility = ["//tensorflow:internal"], +) + +cc_library( + name = "protos_cc", + deps = ["//tensorflow/core/platform/default/build_config:protos_cc"], +) + +# Generates library per group of ops. 
+tf_gen_op_libs( + op_lib_names = [ + "array_ops", + "attention_ops", + "candidate_sampling_ops", + "control_flow_ops", + "data_flow_ops", + "image_ops", + "io_ops", + "linalg_ops", + "logging_ops", + "math_ops", + "nn_ops", + "no_op", + "parsing_ops", + "random_ops", + "sendrecv_ops", + "sparse_ops", + "state_ops", + "string_ops", + "summary_ops", + "training_ops", + ], +) + +# And one for all user ops +cc_library( + name = "user_ops_op_lib", + srcs = glob(["user_ops/**/*.cc"]), + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [":framework"], + alwayslink = 1, +) + +# Low level library tests +tf_cc_tests( + tests = glob( + [ + "lib/**/*_test.cc", + "platform/**/*_test.cc", + ], + exclude = ["lib/strings/ordered_code_test.cc"], + ), + deps = [ + ":lib", + ":test_main", + ], +) + +cc_test( + name = "lib_jpeg_jpeg_mem_unittest", + srcs = ["lib/jpeg/jpeg_mem_unittest.cc"], + data = glob(["lib/jpeg/testdata/*.jpg"]), + deps = [ + ":lib", + ":test_main", + ], +) + +cc_test( + name = "lib_strings_ordered_code_test", + srcs = ["lib/strings/ordered_code_test.cc"], + copts = ["$(STACK_FRAME_UNLIMITED)"], # Tests initialize large vectors + deps = [ + ":lib", + ":test_main", + ], +) + +# higher level tests +tf_cc_tests( + linkstatic = tf_kernel_tests_linkstatic(), + tests = glob( + [ + "client/**/*_test.cc", + "common_runtime/**/*_test.cc", + "framework/**/*_test.cc", + "graph/**/*_test.cc", + "util/**/*_test.cc", + ], + exclude = [ + # TODO(opensource): fix + "common_runtime/gpu/*_test.cc", + # Run by tests below + "common_runtime/gpu/gpu_region_allocator_test.cc", + "common_runtime/gpu/gpu_bfc_allocator_test.cc", + ], + ), + deps = [ + ":core", + ":kernels", + ":lib", + ":local", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + ], +) + +# GPU-related tests +tf_cc_tests( + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + tests = glob( + [ + "kernels/**/*_test.cc", + "user_ops/**/*_test.cc", + "common_runtime/gpu/*_test.cc", + ], + ), + deps = [ + ":kernels", + ":local", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + ], +) + +tf_cuda_library( + name = "stream_executor", + deps = [ + "//tensorflow/core/platform/default/build_config:stream_executor", + ], +) + +cc_library( + name = "cuda", + visibility = [ + ":friends", + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/core/platform/default/build_config:cuda", + ], +) + +cc_library( + name = "tensorflow", + visibility = ["//visibility:public"], + deps = [ + "tensorflow_opensource", + "//tensorflow/core/platform/default/build_config:tensorflow_platform_specific", + ], +) + +cc_library( + name = "core", + visibility = ["//visibility:public"], + deps = [ + ":core_cpu", + ":gpu_runtime", + ], +) + +# Android-specific BUILD targets +load("/tensorflow/tensorflow", "tf_android_core_proto_sources") + +# List of protos we want on android +filegroup( + name = "android_proto_srcs", + srcs = tf_android_core_proto_sources(), + visibility = ["//visibility:public"], +) + +# Core sources. Should eventually become identical to open source +# sources. 
+filegroup( + name = "android_srcs", + srcs = glob( + [ + "client/**/*.cc", + "common_runtime/**/*.h", + "common_runtime/**/*.cc", + "framework/**/*.h", + "framework/**/*.cc", + "graph/**/*.h", + "graph/**/*.cc", + "lib/**/*.h", + "lib/**/*.cc", + "ops/**/*.cc", + "ops/**/*.h", + "platform/*.h", + "platform/*.cc", + "platform/**/*.h", + "platform/**/*.cc", + "public/**/*.h", + "util/**/*.h", + "util/**/*.cc", + "kernels/ops_util.cc", + "kernels/ops_util.h", + "kernels/avgpooling_op.h", + "kernels/maxpooling_op.h", + "kernels/pooling_ops_common.h", + "kernels/pooling_ops_common.cc", + "kernels/reference_gemm.h", + ], + exclude = [ + "**/*test.cc", + "**/*testutil*", + "**/*testlib*", + "**/*main.cc", + "lib/jpeg/*.h", + "lib/jpeg/*.cc", + "lib/png/*.h", + "lib/png/*.cc", + "util/events_writer.cc", + "util/events_writer.h", + # Exclude all protobuf/google headers except protobuf_android.h + "platform/google/cord_coding.h", + "platform/google/dynamic_annotations.h", + "platform/google/integral_types.h", + "platform/google/mutex.h", + "platform/google/protobuf.h", + "platform/google/stream_executor_util.h", + "platform/google/tracing_impl.h", + "platform/google/*.cc", + "platform/google/test_benchmark.cc", + "platform/google/test_benchmark.h", + "kernels/**/*.cu.cc", + "user_ops/**/*.cu.cc", + "common_runtime/gpu/*.cc", + "common_runtime/gpu_device_factory.cc", + ], + ), + visibility = ["//visibility:public"], +) + +# Core kernels we want on Android. Only a subset of kernels to keep +# base library small. +filegroup( + name = "android_core_ops", + srcs = [ + "//tensorflow/core:kernels/aggregate_ops.cc", + "//tensorflow/core:kernels/aggregate_ops.h", + "//tensorflow/core:kernels/assign_op.h", + "//tensorflow/core:kernels/bias_op.cc", + "//tensorflow/core:kernels/bias_op.h", + "//tensorflow/core:kernels/cast_op.cc", + "//tensorflow/core:kernels/cast_op.h", + "//tensorflow/core:kernels/concat_op.cc", + "//tensorflow/core:kernels/concat_op.h", + "//tensorflow/core:kernels/concat_op_cpu.cc", + "//tensorflow/core:kernels/constant_op.cc", + "//tensorflow/core:kernels/constant_op.h", + "//tensorflow/core:kernels/cwise_ops.h", + "//tensorflow/core:kernels/cwise_ops_common.cc", + "//tensorflow/core:kernels/cwise_ops_common.h", + "//tensorflow/core:kernels/dense_update_ops.cc", + "//tensorflow/core:kernels/dense_update_ops.h", + "//tensorflow/core:kernels/fill_functor.h", + "//tensorflow/core:kernels/gather_op.cc", + "//tensorflow/core:kernels/identity_op.cc", + "//tensorflow/core:kernels/identity_op.h", + "//tensorflow/core:kernels/matmul_op.cc", + "//tensorflow/core:kernels/matmul_op.h", + "//tensorflow/core:kernels/no_op.cc", + "//tensorflow/core:kernels/no_op.h", + "//tensorflow/core:kernels/pack_op.cc", + "//tensorflow/core:kernels/reference_gemm.h", + "//tensorflow/core:kernels/reshape_op.cc", + "//tensorflow/core:kernels/reshape_op.h", + "//tensorflow/core:kernels/reverse_sequence_op.cc", + "//tensorflow/core:kernels/reverse_sequence_op.h", + "//tensorflow/core:kernels/sendrecv_ops.cc", + "//tensorflow/core:kernels/sendrecv_ops.h", + "//tensorflow/core:kernels/sequence_ops.cc", + "//tensorflow/core:kernels/shape_ops.cc", + "//tensorflow/core:kernels/slice_op.cc", + "//tensorflow/core:kernels/slice_op.h", + "//tensorflow/core:kernels/softmax_op.cc", + "//tensorflow/core:kernels/softmax_op.h", + "//tensorflow/core:kernels/split_op.cc", + "//tensorflow/core:kernels/split_op.h", + "//tensorflow/core:kernels/split_op_cpu.cc", + "//tensorflow/core:kernels/unpack_op.cc", + 
"//tensorflow/core:kernels/variable_ops.cc", + "//tensorflow/core:kernels/variable_ops.h", + ], + visibility = ["//visibility:public"], +) + +# Other kernels we may want on Android. +filegroup( + name = "android_extended_ops", + srcs = [ + "//tensorflow/core:kernels/avgpooling_op.cc", + "//tensorflow/core:kernels/avgpooling_op.h", + "//tensorflow/core:kernels/control_flow_ops.cc", + "//tensorflow/core:kernels/control_flow_ops.h", + "//tensorflow/core:kernels/conv_2d.h", + "//tensorflow/core:kernels/conv_ops.cc", + "//tensorflow/core:kernels/cwise_op_add.cc", + "//tensorflow/core:kernels/cwise_op_div.cc", + "//tensorflow/core:kernels/cwise_op_exp.cc", + "//tensorflow/core:kernels/cwise_op_log.cc", + "//tensorflow/core:kernels/cwise_op_mul.cc", + "//tensorflow/core:kernels/cwise_op_sigmoid.cc", + "//tensorflow/core:kernels/cwise_op_sqrt.cc", + "//tensorflow/core:kernels/cwise_op_square.cc", + "//tensorflow/core:kernels/cwise_op_sub.cc", + "//tensorflow/core:kernels/cwise_op_tanh.cc", + "//tensorflow/core:kernels/lrn_op.cc", + "//tensorflow/core:kernels/maxpooling_op.cc", + "//tensorflow/core:kernels/maxpooling_op.h", + "//tensorflow/core:kernels/reduction_ops.h", + "//tensorflow/core:kernels/reduction_ops_common.h", + "//tensorflow/core:kernels/reduction_ops_max.cc", + "//tensorflow/core:kernels/reduction_ops_min.cc", + "//tensorflow/core:kernels/reduction_ops_sum.cc", + "//tensorflow/core:kernels/relu_op.cc", + "//tensorflow/core:kernels/relu_op.h", + "//tensorflow/core:kernels/softplus_op.cc", + "//tensorflow/core:kernels/softplus_op.h", + "//tensorflow/core:kernels/transpose_op.cc", + "//tensorflow/core:kernels/transpose_op.h", + "//tensorflow/core:kernels/transpose_op_functor.h", + ], + visibility = ["//visibility:public"], +) + +# Test data +filegroup( + name = "image_testdata", + srcs = [ + # PNG data + "lib/png/testdata/lena_gray.png", + "lib/png/testdata/lena_rgba.png", + # JPEG data + "lib/jpeg/testdata/jpeg_merge_test1.jpg", + "lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg", + # Corrupted JPEG files for tests + "lib/jpeg/testdata/bad_huffman.jpg", + "lib/jpeg/testdata/corrupt.jpg", + # -- hand-edited variant: stops at line 0 + "lib/jpeg/testdata/corrupt34_2.jpg", + # -- hand-edited variant: stops at line 4 + "lib/jpeg/testdata/corrupt34_3.jpg", + # -- hand-edited variant: stops after a restart marker + "lib/jpeg/testdata/corrupt34_4.jpg", + ], +) + +# For portable_proto_library + +# Native library support for Android applications. +# Should be built to target Android with flag --copt=-mfpu=neon. 
+cc_library( + name = "android_tensorflow_lib", + srcs = [ + "//tensorflow/core:android_core_ops", + "//tensorflow/core:android_extended_ops", + "//tensorflow/core:android_srcs", + ], + copts = [ + "-mfpu=neon", + "-std=c++11", + ], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + "@re2//:re2", + ":protos_cc", + "//third_party/eigen3", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc new file mode 100644 index 0000000000..59cf0ed8f9 --- /dev/null +++ b/tensorflow/core/client/tensor_c_api.cc @@ -0,0 +1,370 @@ +#include "tensorflow/core/public/tensor_c_api.h" + +#include + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +// The implementation below is at the top level instead of the +// brain namespace because we are defining 'extern "C"' functions. +using tensorflow::error::Code; +using tensorflow::errors::InvalidArgument; +using tensorflow::gtl::ArraySlice; +using tensorflow::AllocationDescription; +using tensorflow::Status; +using tensorflow::DataType; +using tensorflow::Env; +using tensorflow::GraphDef; +using tensorflow::NewSession; +using tensorflow::Session; +using tensorflow::Tensor; +using tensorflow::TensorBuffer; +using tensorflow::SessionOptions; +using tensorflow::TensorShape; + +extern "C" { + +// -------------------------------------------------------------------------- +struct TF_Status { + Status status; +}; + +TF_Status* TF_NewStatus() { return new TF_Status; } + +void TF_DeleteStatus(TF_Status* s) { delete s; } + +void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { + s->status = Status(static_cast(code), tensorflow::StringPiece(msg)); +} + +TF_Code TF_GetCode(const TF_Status* s) { + return static_cast(s->status.code()); +} + +const char* TF_Message(const TF_Status* s) { + return s->status.error_message().c_str(); +} + +// -------------------------------------------------------------------------- + +namespace { +class TF_ManagedBuffer : public TensorBuffer { + public: + void* data_; + size_t len_; + void (*deallocator_)(void* data, size_t len, void* arg); + void* deallocator_arg_; + + ~TF_ManagedBuffer() override { + (*deallocator_)(data_, len_, deallocator_arg_); + } + + void* data() const override { return data_; } + size_t size() const override { return len_; } + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription(AllocationDescription* proto) const override { + tensorflow::int64 rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); + } +}; + +void deallocate_realigned_buffer(void* data, size_t len, void* arg) { + tensorflow::cpu_allocator()->DeallocateRaw(data); +} +} // namespace + +struct TF_Tensor { + TF_DataType dtype; + TensorShape shape; + TensorBuffer* buffer; +}; + +TF_Tensor* TF_NewTensor(TF_DataType dtype, tensorflow::int64* dims, + int num_dims, void* data, size_t len, + void 
(*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg) { + std::vector dimvec(num_dims); + for (int i = 0; i < num_dims; i++) { + dimvec[i] = dims[i]; + } + + TF_ManagedBuffer* buf = new TF_ManagedBuffer; + buf->len_ = len; + if (reinterpret_cast(data) % EIGEN_MAX_ALIGN_BYTES != 0) { + // Copy the data into a buffer that satisfies Eigen's alignment + // requirements. + buf->data_ = + tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len); + std::memcpy(buf->data_, data, len); + buf->deallocator_ = deallocate_realigned_buffer; + buf->deallocator_arg_ = nullptr; + // Free the original buffer. + deallocator(data, len, deallocator_arg); + } else { + buf->data_ = data; + buf->deallocator_ = deallocator; + buf->deallocator_arg_ = deallocator_arg; + } + return new TF_Tensor{dtype, TensorShape(dimvec), buf}; +} + +void TF_DeleteTensor(TF_Tensor* t) { + t->buffer->Unref(); + delete t; +} + +TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; } +int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); } +tensorflow::int64 TF_Dim(const TF_Tensor* t, int dim_index) { + return t->shape.dim_size(dim_index); +} +size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); } +void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } + +// -------------------------------------------------------------------------- +struct TF_SessionOptions { + SessionOptions options; +}; +TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } +void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } + +void TF_SetTarget(TF_SessionOptions* options, const char* target) { + options->options.target = target; +} + +void TF_SetConfig(TF_SessionOptions* options, const char* config, + size_t config_len, TF_Status* status) { + if (!options->options.config.ParseFromArray(config, config_len)) { + status->status = + tensorflow::errors::InvalidArgument("Unparseable ConfigProto"); + } +} + +// -------------------------------------------------------------------------- +struct TF_Session { + Session* session; +}; + +TF_Session* TF_NewSession(const TF_SessionOptions* opt, TF_Status* status) { + Session* session; + status->status = NewSession(opt->options, &session); + if (status->status.ok()) { + return new TF_Session({session}); + } else { + DCHECK_EQ(nullptr, session); + return NULL; + } +} + +void TF_CloseSession(TF_Session* s, TF_Status* status) { + status->status = s->session->Close(); +} + +void TF_DeleteSession(TF_Session* s, TF_Status* status) { + status->status = Status::OK(); + delete s->session; + delete s; +} + +void TF_ExtendGraph(TF_Session* s, const void* proto, size_t proto_len, + TF_Status* status) { + GraphDef g; + if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument("Invalid GraphDef"); + return; + } + status->status = s->session->Extend(g); +} + +static void DeleteArray(void* data, size_t size, void* arg) { + DCHECK_EQ(data, arg); + delete[] reinterpret_cast(arg); +} + +} // end extern "C" + +namespace tensorflow { + +// Non-static for testing. 
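// ---------------------------------------------------------------------------
// Editorial sketch (not part of the original commit): a minimal lifecycle for
// the C API defined above -- create a session, optionally extend it with an
// already-serialized GraphDef, and tear everything down.  `graph_bytes` and
// `graph_len` are hypothetical inputs; real callers would also use
// TF_SetConfig/TF_SetTarget and TF_Run, which appear elsewhere in this file.
// The string-tensor helpers announced by the "Non-static for testing" comment
// just above continue immediately after this sketch.
// ---------------------------------------------------------------------------
bool ExtendGraphSketch(const void* graph_bytes, size_t graph_len) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* session = TF_NewSession(opts, status);  // NULL on failure
  if (TF_GetCode(status) == TF_OK && graph_bytes != nullptr) {
    TF_ExtendGraph(session, graph_bytes, graph_len, status);
  }
  const bool ok = (TF_GetCode(status) == TF_OK);
  if (session != nullptr) {
    TF_CloseSession(session, status);
    TF_DeleteSession(session, status);
  }
  TF_DeleteSessionOptions(opts);
  TF_DeleteStatus(status);
  return ok;
}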
+bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status) { + const tensorflow::int64 num_elements = src->shape.num_elements(); + const char* input = reinterpret_cast(TF_TensorData(src)); + const size_t src_size = TF_TensorByteSize(src); + if (static_cast(src_size / sizeof(tensorflow::uint64)) < + num_elements) { + status->status = InvalidArgument( + "Malformed TF_STRING tensor; too short to hold number of elements"); + return false; + } + const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; + const char* limit = input + src_size; + + *dst = Tensor(static_cast(src->dtype), src->shape); + auto dstarray = dst->flat(); + for (tensorflow::int64 i = 0; i < num_elements; i++) { + tensorflow::uint64 offset = + reinterpret_cast(input)[i]; + tensorflow::uint64 len; + const char* p; + if (static_cast(offset) >= (limit - data_start) || + !(p = tensorflow::core::GetVarint64Ptr(data_start + offset, limit, + &len)) || + (static_cast(len) > (limit - p))) { + status->status = InvalidArgument("Malformed TF_STRING tensor; element ", + i, " out of range"); + return false; + } + dstarray(i).assign(p, len); + } + return true; +} + +// Non-static for testing. +TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) { + // Compute bytes needed for encoding. + size_t size = 0; + const auto& srcarray = src.flat(); + for (int i = 0; i < srcarray.size(); i++) { + const tensorflow::string& s = srcarray(i); + // uint64 starting_offset, varint64 length, string contents + size += sizeof(tensorflow::uint64) + + tensorflow::core::VarintLength(s.size()) + s.size(); + } + + // Encode all strings. + char* base = new char[size]; + char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size(); + char* dst = data_start; // Where next string is encoded. + tensorflow::uint64* offsets = reinterpret_cast(base); + for (int i = 0; i < srcarray.size(); i++) { + const tensorflow::string& s = srcarray(i); + *offsets = (dst - data_start); + offsets++; + dst = tensorflow::core::EncodeVarint64(dst, s.size()); + memcpy(dst, s.data(), s.size()); + dst += s.size(); + } + CHECK_EQ(dst, base + size); + + auto dims = src.shape().dim_sizes(); + std::vector dimvec(dims.size()); + for (size_t i = 0; i < dims.size(); i++) { + dimvec[i] = dims[i]; + } + return TF_NewTensor(TF_STRING, dimvec.data(), dimvec.size(), base, size, + DeleteArray, base); +} + +class TensorCApi { + public: + static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } + static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, + TensorBuffer* buf) { + return Tensor(static_cast(type), shape, buf); + } +}; + +// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to +// result in a zero-sized tensor. 
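// ---------------------------------------------------------------------------
// Editorial annotation (not part of the original commit): a hand-worked
// example of the TF_STRING wire format produced by TF_Tensor_EncodeStrings
// above and consumed by TF_Tensor_DecodeStrings, for an assumed flat tensor
// holding {"ab", "c"}:
//
//   bytes  0..7    offsets[0] = 0   (uint64, relative to data_start)
//   bytes  8..15   offsets[1] = 3
//   -- data_start = base + 2 * sizeof(uint64) = base + 16 --
//   bytes 16..18   0x02 'a' 'b'     (varint64 length, then contents)
//   bytes 19..20   0x01 'c'
//
// Total size = (8 + 1 + 2) + (8 + 1 + 1) = 21 bytes, matching the per-element
// "uint64 starting_offset, varint64 length, string contents" accounting in
// the encoder.  The zero-sized-tensor helper introduced by the comment just
// above follows immediately below.
// ---------------------------------------------------------------------------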
+static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { + static char empty; + tensorflow::int64 nelems = 1; + std::vector dims; + for (int i = 0; i < shape.dims(); ++i) { + dims.push_back(shape.dim_size(i)); + nelems *= shape.dim_size(i); + } + CHECK_EQ(nelems, 0); + return TF_NewTensor(dtype, dims.data(), shape.dims(), + reinterpret_cast(&empty), 0, + [](void*, size_t, void*) {}, nullptr); +} + +} // namespace tensorflow + +extern "C" { + +void TF_Run(TF_Session* s, + // Input tensors + const char** c_input_names, TF_Tensor** c_inputs, int ninputs, + // Output tensors + const char** c_output_tensor_names, TF_Tensor** c_outputs, + int noutputs, + // Target nodes + const char** c_target_node_names, int ntargets, TF_Status* status) { + status->status = Status::OK(); + for (int i = 0; i < noutputs; i++) { + c_outputs[i] = NULL; + } + + // Initialize inputs. + std::vector> inputs(ninputs); + bool ok = true; + for (int i = 0; i < ninputs; i++) { + TF_Tensor* src = c_inputs[i]; + if (ok) { + inputs[i].first = c_input_names[i]; + if (c_inputs[i]->dtype != TF_STRING) { + inputs[i].second = tensorflow::TensorCApi::MakeTensor( + src->dtype, src->shape, src->buffer); + } else { + // TF_STRING tensors require copying since Tensor class expects + // a sequence of string objects. + ok = + tensorflow::TF_Tensor_DecodeStrings(src, &inputs[i].second, status); + // Must keep looping through all inputs even if there is an error + // so that TF_DeleteTensor() is called unconditionally on all inputs. + } + } + TF_DeleteTensor(src); + } + if (!ok) { + return; + } + + std::vector output_tensor_names(noutputs); + std::vector outputs(noutputs); + std::vector target_node_names(ntargets); + for (int i = 0; i < noutputs; i++) { + output_tensor_names[i] = c_output_tensor_names[i]; + } + for (int i = 0; i < ntargets; i++) { + target_node_names[i] = c_target_node_names[i]; + } + Status result = + s->session->Run(inputs, output_tensor_names, target_node_names, &outputs); + if (!result.ok()) { + status->status = result; + return; + } + + // Store results in c_outputs[] + for (int i = 0; i < noutputs; i++) { + const Tensor& src = outputs[i]; + if (!src.IsInitialized()) { + c_outputs[i] = tensorflow::EmptyTensor( + static_cast(src.dtype()), src.shape()); + continue; + } + if (src.dtype() != tensorflow::DT_STRING) { + // Share the underlying buffer. 
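// Editorial annotation (not part of the original commit): for non-string
// outputs the branch below avoids a copy -- it takes an extra reference on
// the Tensor's underlying TensorBuffer and wraps that same buffer in the
// caller-visible TF_Tensor, so the memory is shared until TF_DeleteTensor
// drops the reference via Unref().  DT_STRING outputs cannot share memory
// because their in-process representation differs from the offset/varint
// wire format, so they are re-encoded with TF_Tensor_EncodeStrings instead.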
+ TensorBuffer* buf = tensorflow::TensorCApi::Buffer(src); + buf->Ref(); + c_outputs[i] = new TF_Tensor{static_cast(src.dtype()), + src.shape(), buf}; + } else { + c_outputs[i] = tensorflow::TF_Tensor_EncodeStrings(src); + } + } +} + +} // end extern "C" diff --git a/tensorflow/core/client/tensor_c_api_test.cc b/tensorflow/core/client/tensor_c_api_test.cc new file mode 100644 index 0000000000..4afdd0c0df --- /dev/null +++ b/tensorflow/core/client/tensor_c_api_test.cc @@ -0,0 +1,94 @@ +#include "tensorflow/core/public/tensor_c_api.h" + +#include +#include "tensorflow/core/public/tensor.h" + +using tensorflow::Tensor; +using tensorflow::TensorShape; + +namespace tensorflow { +bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status); +TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src); +} // namespace tensorflow + +TEST(CApi, Status) { + TF_Status* s = TF_NewStatus(); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + EXPECT_EQ(tensorflow::string(), TF_Message(s)); + TF_SetStatus(s, TF_CANCELLED, "cancel"); + EXPECT_EQ(TF_CANCELLED, TF_GetCode(s)); + EXPECT_EQ(tensorflow::string("cancel"), TF_Message(s)); + TF_DeleteStatus(s); +} + +static void Deallocator(void* data, size_t, void* arg) { + tensorflow::cpu_allocator()->DeallocateRaw(data); + *reinterpret_cast(arg) = true; +} + +TEST(CApi, Tensor) { + float* values = + reinterpret_cast(tensorflow::cpu_allocator()->AllocateRaw( + EIGEN_MAX_ALIGN_BYTES, 6 * sizeof(float))); + tensorflow::int64 dims[] = {2, 3}; + bool deallocator_called = false; + TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, sizeof(values), + &Deallocator, &deallocator_called); + EXPECT_FALSE(deallocator_called); + EXPECT_EQ(TF_FLOAT, TF_TensorType(t)); + EXPECT_EQ(2, TF_NumDims(t)); + EXPECT_EQ(dims[0], TF_Dim(t, 0)); + EXPECT_EQ(dims[1], TF_Dim(t, 1)); + EXPECT_EQ(sizeof(values), TF_TensorByteSize(t)); + EXPECT_EQ(static_cast(values), TF_TensorData(t)); + TF_DeleteTensor(t); + EXPECT_TRUE(deallocator_called); +} + +static void TestEncodeDecode(int line, + const std::vector& data) { + const tensorflow::int64 n = data.size(); + for (std::vector dims : + std::vector>{ + {n}, {1, n}, {n, 1}, {n / 2, 2}}) { + // Create C++ Tensor + Tensor src(tensorflow::DT_STRING, TensorShape(dims)); + for (tensorflow::int64 i = 0; i < src.NumElements(); i++) { + src.flat()(i) = data[i]; + } + TF_Tensor* dst = TF_Tensor_EncodeStrings(src); + + // Convert back to a C++ Tensor and ensure we get expected output. + TF_Status* status = TF_NewStatus(); + Tensor output; + ASSERT_TRUE(TF_Tensor_DecodeStrings(dst, &output, status)) << line; + ASSERT_EQ(TF_OK, TF_GetCode(status)) << line; + ASSERT_EQ(src.NumElements(), output.NumElements()) << line; + for (tensorflow::int64 i = 0; i < src.NumElements(); i++) { + ASSERT_EQ(data[i], output.flat()(i)) << line; + } + + TF_DeleteStatus(status); + TF_DeleteTensor(dst); + } +} + +TEST(CApi, TensorEncodeDecodeStrings) { + TestEncodeDecode(__LINE__, {}); + TestEncodeDecode(__LINE__, {"hello"}); + TestEncodeDecode(__LINE__, + {"the", "quick", "brown", "fox", "jumped", "over"}); + + tensorflow::string big(1000, 'a'); + TestEncodeDecode(__LINE__, {"small", big, "small2"}); +} + +TEST(CApi, SessionOptions) { + TF_SessionOptions* opt = TF_NewSessionOptions(); + TF_DeleteSessionOptions(opt); +} + +// TODO(jeff,sanjay): Session tests +// . Create and delete +// . Extend graph +// . 
Run diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc new file mode 100644 index 0000000000..2e3e7b6597 --- /dev/null +++ b/tensorflow/core/common_runtime/device.cc @@ -0,0 +1,37 @@ +#include "tensorflow/core/common_runtime/device.h" + +#include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +Device::Device(Env* env, const DeviceAttributes& device_attributes, + Allocator* device_allocator) + : DeviceBase(env), device_attributes_(device_attributes) { + CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_)) + << "Invalid device name: " << name(); + rmgr_ = new ResourceMgr(parsed_name_.job); +} + +Device::~Device() { delete rmgr_; } + +// static +DeviceAttributes Device::BuildDeviceAttributes( + const string& name, DeviceType device, Bytes memory_limit, + BusAdjacency bus_adjacency, const string& physical_device_desc) { + DeviceAttributes da; + da.set_name(name); + do { + da.set_incarnation(random::New64()); + } while (da.incarnation() == 0); // This proto field must not be zero + da.set_device_type(device.type()); + da.set_memory_limit(memory_limit.value()); + da.set_bus_adjacency(bus_adjacency); + da.set_physical_device_desc(physical_device_desc); + return da; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h new file mode 100644 index 0000000000..ff3404fea4 --- /dev/null +++ b/tensorflow/core/common_runtime/device.h @@ -0,0 +1,128 @@ +// A Device is a something that can perform computations as part of a +// model. Devices can be local (runs computation on this machine), or +// remote (contacts a device local to another machine using an RPC to +// do the work). Devices are registered in a DeviceSet, which is also +// responsible for the Device <-> id mapping. +// +// Device names +// * Every Device should have a unique name with the format: +// /job:___/replica:___/task:___/(gpu|cpu):___ +// An example name would be "/job:train/replica:0/task:3/gpu:2". +// * Task numbers are within the specified replica, so there are as +// many "task zeros" as replicas. + +#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_H_ +#define TENSORFLOW_COMMON_RUNTIME_DEVICE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class Device : public DeviceBase { + public: + Device(Env* env, const DeviceAttributes& device_attributes, + Allocator* device_allocator); + ~Device() override; + + // Full name of this device (see top comment). 
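The naming scheme described in the header comment above can be exercised directly with DeviceNameUtils, which device.h already depends on; a small sketch, with the example name and helper chosen for illustration only.

    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/util/device_name_utils.h"

    static void ParseDeviceNameExample() {
      tensorflow::DeviceNameUtils::ParsedName parsed;
      // Full names look like /job:<j>/replica:<r>/task:<t>/(gpu|cpu):<id>.
      CHECK(tensorflow::DeviceNameUtils::ParseFullName(
          "/job:train/replica:0/task:3/gpu:2", &parsed));
      // parsed.job and parsed.type are the fields this commit actually reads
      // (see the Device constructor and ExecutorImpl::InferAllocAttr below).
    }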
+ const string& name() const { return device_attributes_.name(); } + + // Parsed name of this device + const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; } + + // Describes what kind of device this is. This is intended to be + // human-readable and not computer-parsed, except that two devices + // with the same device_type() are expected to perform similarly + // (both from a computation and communication perspective). + const string& device_type() const { return device_attributes_.device_type(); } + + // Returns an aggregation of device attributes. + const DeviceAttributes& attributes() const override { + return device_attributes_; + } + + // Performs the actual compute function. + // + // Subclasses may override this function if they wish to perform + // some initialization before each compute. + virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) { + op_kernel->Compute(context); + } + + // Asynchronous kernel's compute. + virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) { + op_kernel->ComputeAsync(context, done); + } + + // Blocks until all operations queued on the device at the time of + // the call have completed. Returns any error pending on the device + // at completion. + virtual Status Sync() = 0; + + // Fill in the context map for the graph. Default behavior is to do + // nothing. + // + // The caller takes ownership over the DeviceContext objects given + // by the device. + virtual Status FillContextMap(const Graph* graph, + DeviceContextMap* device_context_map) { + return Status::OK(); + } + + // Returns the op segment of this device. The caller can reuse op + // kernels registered for the same session running on this device. + OpSegment* op_segment() { return &op_seg_; } + + // Returns the resource manager associated w/ this device. + ResourceMgr* resource_manager() { return rmgr_; } + + // Summarizes the status of this Device, for debugging. + string DebugString() const { return device_attributes_.DebugString(); } + + // Assembles the parameter components into a complete DeviceAttributes value. + static DeviceAttributes BuildDeviceAttributes( + const string& name, DeviceType device, Bytes memory_limit, + BusAdjacency bus_adjacency, const string& physical_device_desc); + + static DeviceAttributes BuildDeviceAttributes(const string& name, + DeviceType device, + Bytes memory_limit, + BusAdjacency bus_adjacency) { + // Pass in an empty string as physical device name. + return BuildDeviceAttributes(name, device, memory_limit, bus_adjacency, ""); + } + + private: + const DeviceAttributes device_attributes_; + DeviceNameUtils::ParsedName parsed_name_; + + // op_seg_ maps session handle and op name to OpKernel objects. + OpSegment op_seg_; + + // Resources associated w/ this device. E.g., shared variables, etc. 
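A sketch of the static helper declared above; the device name, the 256MB memory limit, and the BUS_ANY adjacency are assumed example values. The incarnation field is filled in by the helper with a random, guaranteed nonzero value.

    #include "tensorflow/core/common_runtime/device.h"

    namespace tf = tensorflow;

    static tf::DeviceAttributes ExampleCpuAttributes() {
      // The 256MB limit and BUS_ANY are placeholder choices for this sketch.
      return tf::Device::BuildDeviceAttributes(
          "/job:localhost/replica:0/task:0/cpu:0",
          tf::DeviceType(tf::DEVICE_CPU), tf::Bytes(256 << 20), tf::BUS_ANY,
          "an optional physical device description");
    }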
+ ResourceMgr* rmgr_ = nullptr; + + TF_DISALLOW_COPY_AND_ASSIGN(Device); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/device_factory.cc b/tensorflow/core/common_runtime/device_factory.cc new file mode 100644 index 0000000000..7d391bde1d --- /dev/null +++ b/tensorflow/core/common_runtime/device_factory.cc @@ -0,0 +1,106 @@ +#include "tensorflow/core/common_runtime/device_factory.h" + +#include +#include +#include + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +namespace { + +static mutex* get_device_factory_lock() { + static mutex device_factory_lock; + return &device_factory_lock; +} + +struct FactoryItem { + std::unique_ptr factory; + int priority; +}; + +std::unordered_map& device_factories() { + static std::unordered_map* factories = + new std::unordered_map; + return *factories; +} +} // namespace + +void DeviceFactory::Register(const string& device_type, DeviceFactory* factory, + int priority) { + mutex_lock l(*get_device_factory_lock()); + std::unique_ptr factory_ptr(factory); + std::unordered_map& factories = device_factories(); + auto iter = factories.find(device_type); + if (iter == factories.end()) { + factories[device_type] = {std::move(factory_ptr), priority}; + } else { + if (iter->second.priority < priority) { + iter->second = {std::move(factory_ptr), priority}; + } else if (iter->second.priority == priority) { + LOG(FATAL) << "Duplicate registration of device factory for type " + << device_type << " with the same priority " << priority; + } + } +} + +DeviceFactory* DeviceFactory::GetFactory(const string& device_type) { + mutex_lock l(*get_device_factory_lock()); // could use reader lock + auto it = device_factories().find(device_type); + if (it == device_factories().end()) { + return nullptr; + } + return it->second.factory.get(); +} + +void DeviceFactory::AddDevices(const SessionOptions& options, + const string& name_prefix, + std::vector* devices) { + // CPU first. + auto cpu_factory = GetFactory("CPU"); + if (!cpu_factory) { + LOG(FATAL) + << "CPU Factory not registered. Did you link in threadpool_device?"; + } + size_t init_size = devices->size(); + cpu_factory->CreateDevices(options, name_prefix, devices); + if (devices->size() == init_size) { + LOG(FATAL) << "No CPU devices are available in this process"; + } + + // Then GPU. + auto gpu_factory = GetFactory("GPU"); + if (gpu_factory) { + gpu_factory->CreateDevices(options, name_prefix, devices); + } + + // Then the rest. 
+ mutex_lock l(*get_device_factory_lock()); + for (auto& p : device_factories()) { + auto factory = p.second.factory.get(); + if (factory != cpu_factory && factory != gpu_factory) { + factory->CreateDevices(options, name_prefix, devices); + } + } +} + +Device* DeviceFactory::NewDevice(const string& type, + const SessionOptions& options, + const string& name_prefix) { + auto device_factory = GetFactory(type); + if (!device_factory) { + return nullptr; + } + SessionOptions opt = options; + (*opt.config.mutable_device_count())[type] = 1; + std::vector devices; + device_factory->CreateDevices(opt, name_prefix, &devices); + CHECK_EQ(devices.size(), 1); + return devices[0]; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h new file mode 100644 index 0000000000..57b625b3e5 --- /dev/null +++ b/tensorflow/core/common_runtime/device_factory.h @@ -0,0 +1,69 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_ +#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_ + +#include +#include +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class Device; +struct SessionOptions; + +class DeviceFactory { + public: + virtual ~DeviceFactory() {} + static void Register(const string& device_type, DeviceFactory* factory, + int priority); + static DeviceFactory* GetFactory(const string& device_type); + + // Append to "*devices" all suitable devices, respecting + // any device type specific properties/counts listed in "options". + // + // CPU devices are added first. + static void AddDevices(const SessionOptions& options, + const string& name_prefix, + std::vector* devices); + + // Helper for tests. Create a single device of type "type". The + // returned device is always numbered zero, so if creating multiple + // devices of the same type, supply distinct name_prefix arguments. + static Device* NewDevice(const string& type, const SessionOptions& options, + const string& name_prefix); + + // Most clients should call AddDevices() instead. + virtual void CreateDevices(const SessionOptions& options, + const string& name_prefix, + std::vector* devices) = 0; +}; + +namespace dfactory { + +template +class Registrar { + public: + // Multiple registrations for the same device type with different priorities + // are allowed. The registration with the highest priority will be used. + explicit Registrar(const string& device_type, int priority = 0) { + DeviceFactory::Register(device_type, new Factory(), priority); + } +}; + +} // namespace dfactory + +#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \ + INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ + __COUNTER__, ##__VA_ARGS__) + +#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ + ctr, ...) 
\ + static ::tensorflow::dfactory::Registrar \ + INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \ + ##__VA_ARGS__) + +// __COUNTER__ must go through another macro to be properly expanded +#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_ + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_ diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc new file mode 100644 index 0000000000..4fa13f6b4b --- /dev/null +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -0,0 +1,90 @@ +#include "tensorflow/core/common_runtime/device_mgr.h" + +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +DeviceMgr::DeviceMgr(const std::vector& devices) { + for (Device* d : devices) { + devices_.push_back(d); + + // Register under both the full name and the local name. + device_map_[d->name()] = d; + device_map_[DeviceNameUtils::LocalName(d->name())] = d; + device_type_counts_[d->device_type()]++; + } +} + +DeviceMgr::~DeviceMgr() { + for (auto p : devices_) delete p; +} + +void DeviceMgr::ListDeviceAttributes( + std::vector* devices) const { + devices->reserve(devices_.size()); + for (Device* dev : devices_) { + devices->emplace_back(dev->attributes()); + } +} + +std::vector DeviceMgr::ListDevices() const { + return std::vector(devices_.begin(), devices_.end()); +} + +string DeviceMgr::DebugString() const { + string out; + for (Device* dev : devices_) { + strings::StrAppend(&out, dev->name(), "\n"); + } + return out; +} + +string DeviceMgr::DeviceMappingString() const { + string out; + for (Device* dev : devices_) { + if (!dev->attributes().physical_device_desc().empty()) { + strings::StrAppend(&out, dev->name(), " -> ", + dev->attributes().physical_device_desc(), "\n"); + } + } + return out; +} + +Status DeviceMgr::LookupDevice(const string& name, Device** device) const { + Status s; + auto iter = device_map_.find(name); + if (iter == device_map_.end()) { + return errors::InvalidArgument(name, " unknown device."); + } + *device = iter->second; + return Status::OK(); +} + +void DeviceMgr::ClearContainers(gtl::ArraySlice containers) const { + Status s; + for (Device* dev : devices_) { + if (containers.empty()) { + s.Update(dev->resource_manager()->Cleanup( + dev->resource_manager()->default_container())); + } else { + for (const string& c : containers) { + s.Update(dev->resource_manager()->Cleanup(c)); + } + } + if (!s.ok()) { + LOG(WARNING) << s; + } + } +} + +int DeviceMgr::NumDeviceType(const string& type) const { + auto iter = device_type_counts_.find(type); + if (iter != device_type_counts_.end()) return iter->second; + return 0; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h new file mode 100644 index 0000000000..c57d0222aa --- /dev/null +++ b/tensorflow/core/common_runtime/device_mgr.h @@ -0,0 +1,55 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_ +#define TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class DeviceAttributes; + +class DeviceMgr { 
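A sketch of how the registration macro above is meant to be used, for a hypothetical "TOY" device type; the ToyDeviceFactory name, the priority of 50, and the empty CreateDevices body are placeholders (a real factory appends Device subclasses to *devices).

    #include <vector>
    #include "tensorflow/core/common_runtime/device_factory.h"
    #include "tensorflow/core/public/session_options.h"

    namespace tensorflow {

    class ToyDeviceFactory : public DeviceFactory {
     public:
      void CreateDevices(const SessionOptions& options, const string& name_prefix,
                         std::vector<Device*>* devices) override {
        // A real factory would push_back one Device per requested instance,
        // honoring any per-type count carried in options.config.
      }
    };

    // Priority 50 shadows any default-priority registration for "TOY";
    // registering twice at the same priority is fatal (see Register above).
    REGISTER_LOCAL_DEVICE_FACTORY("TOY", ToyDeviceFactory, 50);

    }  // namespace tensorflow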
+ public: + // TODO(zhifengc): Other initialization information. + explicit DeviceMgr(const std::vector& devices); + ~DeviceMgr(); + + // Returns attributes of all devices. + void ListDeviceAttributes(std::vector* devices) const; + + std::vector ListDevices() const; + + // Returns a string listing all devices. + string DebugString() const; + + // Returns a string of all the device mapping. + string DeviceMappingString() const; + + // Assigns *device with pointer to Device of the given name. + // Accepts either a full device name, or just the replica-local suffix. + Status LookupDevice(const string& name, Device** device) const; + + // Clears given containers of all devices if 'container' is + // non-empty. Otherwise, clears default containers of all devices. + void ClearContainers(gtl::ArraySlice containers) const; + + int NumDeviceType(const string& type) const; + + private: + typedef gtl::InlinedVector DeviceVec; + DeviceVec devices_; + std::unordered_map device_map_; + std::unordered_map device_type_counts_; + + TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_ diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc new file mode 100644 index 0000000000..3b0465d9a6 --- /dev/null +++ b/tensorflow/core/common_runtime/device_set.cc @@ -0,0 +1,68 @@ +#include "tensorflow/core/common_runtime/device_set.h" + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +DeviceSet::DeviceSet() {} + +DeviceSet::~DeviceSet() {} + +void DeviceSet::AddDevice(Device* device) { + devices_.push_back(device); + device_by_name_.insert({device->name(), device}); +} + +void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec, + std::vector* devices) const { + // TODO(jeff): If we are going to repeatedly lookup the set of devices + // for the same spec, maybe we should have a cache of some sort + devices->clear(); + for (Device* d : devices_) { + if (DeviceNameUtils::IsCompleteSpecification(spec, d->parsed_name())) { + devices->push_back(d); + } + } +} + +Device* DeviceSet::FindDeviceByName(const string& name) const { + return gtl::FindPtrOrNull(device_by_name_, name); +} + +// Higher result implies lower priority. +static int Order(const DeviceType& d) { + if (StringPiece(d.type()) == DEVICE_CPU) { + return 3; + } else if (StringPiece(d.type()) == DEVICE_GPU) { + return 2; + } else { + return 1; + } +} + +static bool ByPriority(const DeviceType& a, const DeviceType& b) { + // Order by "order number"; break ties lexicographically. 
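A sketch tying DeviceFactory and DeviceMgr together, assuming a CPU factory is linked in (otherwise AddDevices above aborts); the name prefix is an assumed example, and the DeviceMgr takes ownership of the devices it is handed.

    #include <vector>
    #include "tensorflow/core/common_runtime/device_factory.h"
    #include "tensorflow/core/common_runtime/device_mgr.h"
    #include "tensorflow/core/public/session_options.h"

    static void BuildAndQueryDeviceMgr() {
      tensorflow::SessionOptions options;
      std::vector<tensorflow::Device*> devices;
      tensorflow::DeviceFactory::AddDevices(
          options, "/job:localhost/replica:0/task:0", &devices);

      tensorflow::DeviceMgr device_mgr(devices);  // deletes the devices in its dtor
      tensorflow::Device* cpu0 = nullptr;
      tensorflow::Status s = device_mgr.LookupDevice(
          "/job:localhost/replica:0/task:0/cpu:0", &cpu0);
      // Per the header comment above, the replica-local suffix of the name is
      // accepted as well.
    }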
+ return std::make_pair(Order(a), StringPiece(a.type())) < + std::make_pair(Order(b), StringPiece(b.type())); +} + +std::vector DeviceSet::PrioritizedDeviceTypeList() const { + std::vector result; + std::set seen; + for (Device* d : devices_) { + auto t = d->device_type(); + if (seen.insert(t).second) { + result.emplace_back(DeviceType(t)); + } + } + std::sort(result.begin(), result.end(), ByPriority); + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h new file mode 100644 index 0000000000..130d965891 --- /dev/null +++ b/tensorflow/core/common_runtime/device_set.h @@ -0,0 +1,64 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_ +#define TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// DeviceSet is a container class for managing the various types of +// devices used by a model. +class DeviceSet { + public: + DeviceSet(); + ~DeviceSet(); + + // Does not take ownership of 'device'. + void AddDevice(Device* device); + + // Set the device designated as the "client". This device + // must also be registered via AddDevice(). + void set_client_device(Device* device) { client_device_ = device; } + + // Returns a pointer to the device designated as the "client". + Device* client_device() const { return client_device_; } + + // Return the list of devices in this set. + const std::vector& devices() const { return devices_; } + + // Given a DeviceNameUtils::ParsedName (which may have some + // wildcards for different components), fills "*devices" with all + // devices in "*this" that match "spec". + void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec, + std::vector* devices) const; + + // Finds the device with the given "fullname". Returns nullptr if + // not found. + Device* FindDeviceByName(const string& fullname) const; + + // Return the list of unique device types in this set, ordered + // with more preferable devices earlier. + std::vector PrioritizedDeviceTypeList() const; + + private: + // Not owned. + std::vector devices_; + + // Fullname -> device* for device in devices_. + std::unordered_map device_by_name_; + + // client_device_ points to an element of devices_ that we consider + // to be the client device (in this local process). + Device* client_device_ = nullptr; + + TF_DISALLOW_COPY_AND_ASSIGN(DeviceSet); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_ diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc new file mode 100644 index 0000000000..1b80a5b697 --- /dev/null +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -0,0 +1,65 @@ +#include "tensorflow/core/common_runtime/device_set.h" + +#include "tensorflow/core/public/status.h" +#include + +namespace tensorflow { +namespace { + +// Return a fake device with the specified type and name. 
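In contrast to DeviceMgr, a DeviceSet does not own its devices. A minimal sketch of populating one and asking for the placement preference order implied by Order() above (custom device types ahead of GPU, GPU ahead of CPU); the PopulateDeviceSet helper is hypothetical.

    #include <vector>
    #include "tensorflow/core/common_runtime/device_set.h"

    static void PopulateDeviceSet(
        const std::vector<tensorflow::Device*>& devices,
        tensorflow::DeviceSet* device_set) {
      for (tensorflow::Device* d : devices) {
        device_set->AddDevice(d);  // does not transfer ownership
      }
      if (!devices.empty()) device_set->set_client_device(devices[0]);
      // Most-preferred device types come first in the returned list.
      std::vector<tensorflow::DeviceType> types =
          device_set->PrioritizedDeviceTypeList();
    }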
+static Device* Dev(const char* type, const char* name) { + class FakeDevice : public Device { + public: + explicit FakeDevice(const DeviceAttributes& attr) + : Device(nullptr, attr, nullptr) {} + Status Sync() override { return Status::OK(); } + Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } + }; + DeviceAttributes attr; + attr.set_name(name); + attr.set_device_type(type); + return new FakeDevice(attr); +} + +class DeviceSetTest : public testing::Test { + public: + void AddDevice(const char* type, const char* name) { + Device* d = Dev(type, name); + owned_.emplace_back(d); + devices_.AddDevice(d); + } + + std::vector types() const { + return devices_.PrioritizedDeviceTypeList(); + } + + private: + DeviceSet devices_; + std::vector> owned_; +}; + +TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) { + EXPECT_EQ(std::vector{}, types()); + + AddDevice("CPU", "/job:a/replica:0/task:0/cpu:0"); + EXPECT_EQ(std::vector{DeviceType(DEVICE_CPU)}, types()); + + AddDevice("CPU", "/job:a/replica:0/task:0/cpu:1"); + EXPECT_EQ(std::vector{DeviceType(DEVICE_CPU)}, types()); + + AddDevice("GPU", "/job:a/replica:0/task:0/gpu:0"); + EXPECT_EQ( + (std::vector{DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}), + types()); + + AddDevice("T1", "/job:a/replica:0/task:0/device:T1:0"); + AddDevice("T1", "/job:a/replica:0/task:0/device:T1:1"); + AddDevice("T2", "/job:a/replica:0/task:0/device:T2:0"); + EXPECT_EQ( + (std::vector{DeviceType("T1"), DeviceType("T2"), + DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}), + types()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eigen_thread_pool.h b/tensorflow/core/common_runtime/eigen_thread_pool.h new file mode 100644 index 0000000000..2554f3521b --- /dev/null +++ b/tensorflow/core/common_runtime/eigen_thread_pool.h @@ -0,0 +1,22 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_ +#define TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_ + +#include "tensorflow/core/lib/core/threadpool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { pool_->Schedule(fn); } + + private: + thread::ThreadPool* pool_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_ diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc new file mode 100644 index 0000000000..7f2473f93b --- /dev/null +++ b/tensorflow/core/common_runtime/executor.cc @@ -0,0 +1,2118 @@ +#include "tensorflow/core/common_runtime/executor.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/edgeset.h" +#include "tensorflow/core/lib/core/errors.h" +#include 
"tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" + +namespace tensorflow { + +namespace { + +// 1-D, 0 element tensor. +static const Tensor* const kEmptyTensor = new Tensor; + +bool IsInitializationOp(const Node* node) { + return node->op_def().allows_uninitialized_input(); +} + +// Sets the timeline_label field of *node_stats, using data from *node. +// Returns true iff the node is a transfer node. +// TODO(tucker): merge with the DetailText function in session.cc +// in a common location. +bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) { + bool is_transfer_node = false; + string memory; + for (auto& all : node_stats->memory()) { + int64 tot = all.total_bytes(); + if (tot >= 0.1 * 1048576.0) { + int64 peak = all.peak_bytes(); + if (peak > 0) { + memory = + strings::StrCat(memory, "[", all.allocator_name(), + strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0, + peak / 1048576.0)); + } else { + memory = strings::StrCat(memory, "[", all.allocator_name(), + strings::Printf(" %.1fMB] ", tot / 1048576.0)); + } + } + } + const NodeDef& def = node->def(); + string text = ""; + if (IsSend(node)) { + string tensor_name; + TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name)); + string recv_device; + TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device)); + text = strings::StrCat(memory, def.name(), " = ", def.op(), "(", + tensor_name, " @", recv_device); + is_transfer_node = true; + } else if (IsRecv(node)) { + string tensor_name; + TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name)); + string send_device; + TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device)); + text = strings::StrCat(memory, def.name(), " = ", def.op(), "(", + tensor_name, " @", send_device); + is_transfer_node = true; + } else { + text = strings::StrCat( + memory, def.name(), " = ", def.op(), "(", + str_util::Join( + std::vector(def.input().begin(), def.input().end()), + ", "), + ")"); + } + node_stats->set_timeline_label(text); + return is_transfer_node; +} + +// Helper routines for collecting step stats. 
+namespace nodestats { +inline int64 NowInUsec() { return Env::Default()->NowMicros(); } + +void SetScheduled(NodeExecStats* nt, int64 t) { nt->set_scheduled_micros(t); } + +void SetAllStart(NodeExecStats* nt) { nt->set_all_start_micros(NowInUsec()); } + +void SetOpStart(NodeExecStats* nt) { + DCHECK_NE(nt->all_start_micros(), 0); + nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros()); +} + +void SetOpEnd(NodeExecStats* nt) { + DCHECK_NE(nt->all_start_micros(), 0); + nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros()); +} + +void SetAllEnd(NodeExecStats* nt) { + DCHECK_NE(nt->all_start_micros(), 0); + nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros()); +} + +void SetOutput(NodeExecStats* nt, int slot, AllocationType allocation_type, + const Tensor* v) { + DCHECK(v); + NodeOutput* no = nt->add_output(); + no->set_slot(slot); + no->set_allocation_type(allocation_type); + v->FillDescription(no->mutable_tensor_description()); +} + +void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) { + for (const auto& allocator_pair : ctx->wrapped_allocators()) { + AllocatorMemoryUsed* memory = nt->add_memory(); + // retrieving the sizes from the wrapped allocator removes the + // executor's reference to it, so allocator_pair.second must not + // be dereferenced again after this statement + auto sizes = allocator_pair.second->GetSizesAndUnRef(); + memory->set_allocator_name(allocator_pair.first->Name()); + int tb = sizes.first; + memory->set_total_bytes(tb); + if (allocator_pair.first->TracksAllocationSizes()) { + memory->set_peak_bytes(sizes.second); + } + } +} +} // namespace nodestats + +struct NodeItem { + // A graph node. + const Node* node = nullptr; + + // The kernel for this node. + OpKernel* kernel = nullptr; + + // ExecutorImpl::tensors_[input_start] is the 1st positional input + // for this node. + int input_start = 0; +}; + +// Map from std::pair to attributes. +struct pairhash { + public: + template + std::size_t operator()(const std::pair& x) const { + return std::hash()(x.first) ^ std::hash()(x.second); + } +}; +typedef std::unordered_map, AllocatorAttributes, pairhash> + DevAttrMap; + +typedef gtl::InlinedVector TensorValueVec; +typedef gtl::InlinedVector DeviceContextVec; +typedef gtl::InlinedVector AllocatorAttributeVec; + +class ExecutorImpl : public Executor { + public: + ExecutorImpl(const LocalExecutorParams& p, const Graph* g) + : params_(p), graph_(g) { + CHECK(p.create_kernel != nullptr); + CHECK(p.delete_kernel != nullptr); + } + + ~ExecutorImpl() override { + for (NodeItem& item : nodes_) { + params_.delete_kernel(item.kernel); + } + delete graph_; + } + + Status Initialize(); + + // Infer memory allocation attributes of a node n's output, + // based on its use node dst. Note that dst might not be directly + // connected to n by a single edge, but might be a downstream + // consumer of n's output by reference. *attr is updated with any + // necessary attributes. + Status InferAllocAttr(const Node* n, const Node* dst, + const DeviceNameUtils::ParsedName& local_dev_name, + AllocatorAttributes* attr); + + // Process all Nodes in the current graph, attempting to infer the + // memory allocation attributes to be used wherever they may allocate + // a tensor buffer. + Status SetAllocAttrs(); + + void RunAsync(const Args& args, DoneCallback done) override; + + private: + friend class ExecutorState; + friend class SimpleExecutorState; + + // Owned. 
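A standalone worked example of the flat input layout described in NodeItem above: input_start is a running sum of num_inputs over node ids, so node i's j-th input lives at slot nodes[i].input_start + j of the per-iteration input vector. The input counts below are made up.

    #include <cstddef>
    #include <vector>

    int main() {
      const std::vector<int> num_inputs = {2, 0, 3, 1};  // one entry per node id
      std::vector<int> input_start(num_inputs.size());
      int total_tensors = 0;
      for (std::size_t id = 0; id < num_inputs.size(); ++id) {
        input_start[id] = total_tensors;  // -> {0, 2, 2, 5}
        total_tensors += num_inputs[id];  // -> 6, the sum of all num_inputs
      }
      // Node 2's three inputs occupy slots 2, 3 and 4; total_tensors is what
      // sizes the per-iteration input_tensors vector (total_tensors_ above).
      return 0;
    }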
+ LocalExecutorParams params_; + const Graph* graph_; + std::vector nodes_; // nodes_.size == graph_.num_node_ids(). + int total_tensors_ = 0; // total_tensors_ = sum(nodes_[*].num_inputs()) + + // The number of inputs for each frame in this graph. This is static + // information of the graph. + std::unordered_map frame_input_count_; + + DevAttrMap alloc_attr_; + + TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl); +}; + +Status ExecutorImpl::Initialize() { + const int num_nodes = graph_->num_node_ids(); + nodes_.resize(num_nodes); + + Status s; + total_tensors_ = 0; + + // Preprocess every node in the graph to create an instance of op + // kernel for each node; + for (const Node* n : graph_->nodes()) { + const int id = n->id(); + NodeItem* item = &nodes_[id]; + item->node = n; + item->input_start = total_tensors_; + total_tensors_ += n->num_inputs(); + s = params_.create_kernel(n->def(), &item->kernel); + if (!s.ok()) { + s = AttachDef(s, n->def()); + LOG(ERROR) << "Executor failed to create kernel. " << s; + break; + } + CHECK(item->kernel); + + // Initialize static information about the frames in the graph. + if (IsEnter(n)) { + string frame_name; + s = GetNodeAttr(n->def(), "frame_name", &frame_name); + if (!s.ok()) return s; + ++frame_input_count_[frame_name]; + } + } + if (params_.has_control_flow) { + VLOG(2) << "Graph has control flow."; + } + if (!s.ok()) return s; + return SetAllocAttrs(); +} + +Status ExecutorImpl::SetAllocAttrs() { + Status s; + Device* device = params_.device; + DeviceNameUtils::ParsedName local_dev_name = device->parsed_name(); + + for (const Node* n : graph_->nodes()) { + // Examine the out edges of each node looking for special use + // cases that may affect memory allocation attributes. + for (auto e : n->out_edges()) { + AllocatorAttributes attr; + s = InferAllocAttr(n, e->dst(), local_dev_name, &attr); + if (!s.ok()) return s; + if (attr.value != 0) { + VLOG(2) << "node " << n->name() << " gets attr " << attr.value + << " for output " << e->src_output(); + alloc_attr_[std::make_pair(n->id(), e->src_output())].Merge(attr); + } else { + VLOG(2) << "default output attr for node " << n->name() << " output " + << e->src_output(); + } + } + } + return s; +} + +Status ExecutorImpl::InferAllocAttr( + const Node* n, const Node* dst, + const DeviceNameUtils::ParsedName& local_dev_name, + AllocatorAttributes* attr) { + Status s; + if (IsSend(dst)) { + string dst_name; + s = GetNodeAttr(dst->def(), "recv_device", &dst_name); + if (!s.ok()) return s; + DeviceNameUtils::ParsedName parsed_dst_name; + if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) { + s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ", + n->name()); + return s; + } + if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) { + // Value is going to be the source of an RPC. + attr->set_nic_compatible(true); + VLOG(2) << "node " << n->name() << " is the source of an RPC out"; + } else if (local_dev_name.type == "CPU" && parsed_dst_name.type == "GPU") { + // Value is going to be the source of a local DMA from CPU to GPU. 
+ attr->set_gpu_compatible(true); + VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy"; + } else { + VLOG(2) << "default alloc case local type " << local_dev_name.type + << " remote type " << parsed_dst_name.type; + } + } else if (dst->type_string() == "ToFloat") { + for (auto e : dst->out_edges()) { + s = InferAllocAttr(n, e->dst(), local_dev_name, attr); + if (!s.ok()) return s; + } + } + return s; +} + +// The state associated with one invokation of ExecutorImpl::Run. +// ExecutorState dispatches nodes when they become ready and keeps +// track of how many predecessors of a node have not done (pending_). +class ExecutorState { + public: + ExecutorState(const Executor::Args& args, ExecutorImpl* impl); + ~ExecutorState(); + + void RunAsync(Executor::DoneCallback done); + + private: + typedef ExecutorState ME; + + // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). + // TODO(yuanbyu): A better way to do "has_value"? + struct Entry { + Tensor val = *kEmptyTensor; // A tensor value. + Tensor* ref = nullptr; // A tensor reference. + mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr. + bool has_value = false; // Whether the value exists + + // Every entry carries an optional DeviceContext containing + // Device-specific information about how the Tensor was produced. + DeviceContext* device_context = nullptr; + + // The attributes of the allocator that creates the tensor. + AllocatorAttributes alloc_attr; + }; + + // Contains a map from node id to the DeviceContext object that was + // assigned by the device at the beginning of a step. + DeviceContextMap device_context_map_; + + struct IterationState { + // The state of an iteration. + + // The pending count for each graph node. One copy per iteration. + // Iteration i can be garbage collected when it is done. + // TODO(yuanbyu): This vector currently has size of the number of nodes + // in this partition. This is not efficient if the subgraph for the frame + // is only a small subset of the partition. We should make the vector + // size to be only the size of the frame subgraph. + std::vector* pending_count; + + // The dead input count for each graph node. One copy per iteration. + std::vector* dead_count; + + // One copy per iteration. For iteration k, i-th node's j-th input is in + // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either + // a tensor pointer (pass-by-reference) or a tensor (pass-by-value). + // + // NOTE: No need to protect input_tensors[i] by any locks because it + // is resized once. Each element of tensors_ is written once by the + // source node of an edge and is cleared by the destination of the same + // edge. The latter node is never run concurrently with the former node. + std::vector* input_tensors; + + // The number of outstanding ops for each iteration. + int outstanding_ops; + + // The number of outstanding frames for each iteration. + int outstanding_frame_count; + + ~IterationState() { + delete pending_count; + delete dead_count; + delete input_tensors; + } + }; + + struct FrameState { + // A new frame is created for each loop. Execution starts at iteration 0. + // When a value at iteration 0 passes through a NextIteration node, + // iteration 1 is created and starts running. Note that iteration 0 may + // still be running so multiple iterations may run in parallel. The + // frame maintains the state of iterations in several data structures + // such as pending_count and input_tensors. 
When iteration 0 completes, + // we garbage collect the state of iteration 0. + // + // A frame instance is considered "done" and can be garbage collected + // if all its inputs have entered and all its iterations are "done". + // + // A frame manages the live iterations of an iterative computation. + // Iteration i is considered "done" when there are no outstanding ops, + // frames at iteration i are done, all recvs for this iteration are + // completed, and iteration i-1 is done. For iteration 0, we instead + // wait for there to be no more pending inputs of the frame. + // + // Frames and iterations are garbage collected once they are done. + // The state we need to keep around is highly dependent on the + // parallelism enabled by the scheduler. We may want to have the + // scheduler dynamically control the outstanding number of live + // parallel frames and iterations. To reduce the state space, the + // scheduler might want to schedule ops in inner frames first and + // lower iterations first. + // + // This frame state is mostly initialized lazily on demand so we + // don't introduce unnecessary overhead. + + // The name of this frame, which is the concatenation of its parent + // frame name, the iteration of the parent frame when this frame was + // created, and the value of the attr 'frame_name'. + string frame_name; + + // The unique id for this frame. Generated by fingerprinting + // frame_name. + uint64 frame_id; + + // The iteration id of its parent frame when this frame is created. + // -1 if there is no parent frame. The frame_name/parent_iter pair + // uniquely identifies this FrameState. + int64 parent_iter = -1; + + // The FrameState of its parent frame. + FrameState* parent_frame = nullptr; + + // The highest iteration number we have reached so far in this frame. + int64 iteration_count = 0; + + // The number of inputs this frame is still waiting. + int num_pending_inputs = 0; + + // The number of outstanding iterations. + int num_outstanding_iterations = 0; + + // The maximum allowed number of parallel iterations. + int max_parallel_iterations = 1; + + // The iteration states of this frame. + std::vector iterations; + + // The NextIteration nodes to enter a new iteration. If the number of + // outstanding iterations reaches the limit, we will defer the start of + // the next iteration until the number of outstanding iterations falls + // below the limit. + std::vector> next_iter_roots; + + // The values of the loop invariants for this loop. They are added into + // this list as they "enter" the frame. When a loop invariant enters, + // we make it available to all active iterations. When the frame starts + // a new iteration, we make all the current loop invariants available + // to the new iteration. + std::vector> inv_values; + + // The list of dead exit nodes for the current highest iteration. We + // will only "execute" the dead exits of the final iteration. + std::vector dead_exits; + + IterationState* GetIteration(int64 iter) { + int index = iter % iterations.size(); + return iterations[index]; + } + + void SetIteration(int64 iter, IterationState* state) { + int index = iter % iterations.size(); + iterations[index] = state; + } + + ~FrameState() { + for (size_t i = 0; i < iterations.size(); ++i) { + delete iterations[i]; + iterations[i] = nullptr; + } + } + }; + + // A tagged node: . 
+ struct TaggedNode { + const Node* node = nullptr; + FrameState* input_frame = nullptr; + int64 input_iter = -1; + bool is_dead = false; + + TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter, + bool dead) { + node = t_node; + input_frame = in_frame; + input_iter = in_iter; + is_dead = dead; + } + }; + + typedef gtl::InlinedVector TaggedNodeSeq; + typedef gtl::InlinedVector EntryVector; + + // Not owned. + Rendezvous* rendezvous_; + StepStatsCollector* stats_collector_; + // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper instead of a + // pointer? (avoids having to delete). + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; + FunctionCallFrame* call_frame_; + const ExecutorImpl* impl_; + CancellationManager* cancellation_manager_; + Executor::Args::Runner runner_; + + // Owned. + + // Step-local resource manager. + ResourceMgr step_resource_manager_; + + // The root frame in which the execution of this step is started. + FrameState* root_frame_; + + // Invoked when the execution finishes. + Executor::DoneCallback done_cb_; + + std::atomic_int_fast32_t num_outstanding_ops_; + + mutex mu_; + Status status_ GUARDED_BY(mu_); + + // Mapping from frame name to outstanding frames. A new frame is created + // at some iteration of an active frame. So the unique key for the new + // child frame is composed of the name of the parent frame, the iteration + // number at which the parent frame is creating the new frame, and the + // name of the new frame from nodedef. + std::unordered_map outstanding_frames_ GUARDED_BY(mu_); + + // The unique name of a frame. + inline string MakeFrameName(FrameState* frame, int64 iter_id, string name) { + return strings::StrCat(frame->frame_name, ";", iter_id, ";", name); + } + + // Initialize the pending count for a graph. + static void InitializePending(const Graph* graph, std::vector* pending); + + // Find an existing or create a new child frame in the frame 'frame' at + // iteration 'iter'. + void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node, + FrameState** child) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Increments the iteration id. If this is a new iteration, initialize it. + void IncrementIteration(FrameState* frame, TaggedNodeSeq* ready) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns true if the computation in the frame is completed. + bool IsFrameDone(FrameState* frame) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns true if the iteration of the frame is completed. + bool IsIterationDone(FrameState* frame, int64 iter) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Get the output frame/iter of a node. Create new frame/iteration if + // needed. If there are dead roots for the new iteration, we need to + // "execute" them so ad them to the ready queue. Returns true if + // we need to check for the completion of output frame/iter. + bool SetOutputFrameIter(const TaggedNode& tagged_node, + const EntryVector& outputs, FrameState** frame, + int64* iter, TaggedNodeSeq* ready) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Cleanup frames and iterations + void CleanupFramesIterations(FrameState* frame, int64 iter, + TaggedNodeSeq* ready) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Activate all the deferred NextIteration nodes in a new iteration. + void ActivateNexts(FrameState* frame, int64 iter, TaggedNodeSeq* ready) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Activate all the current loop invariants in a new iteration. 
+ void ActivateLoopInvs(FrameState* frame, int64 iter, TaggedNodeSeq* ready) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Add a new loop invariant and make it available to all active iterations. + void AddLoopInv(FrameState* frame, const Node* node, const Entry& value, + TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Activate the successors of a node. + void ActivateNode(const Node* node, const bool is_dead, FrameState* frame, + int64 iter, const EntryVector& outputs, + TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Process a ready node in current thread. + void Process(TaggedNode node, int64 scheduled_usec); + + // Before invoking item->kernel, fills in its "inputs". + Status PrepareInputs(const NodeItem& item, Entry* first_input, + TensorValueVec* inputs, + DeviceContextVec* input_device_contexts, + AllocatorAttributeVec* input_alloc_attrs, + bool* is_input_dead); + + // After item->kernel computation is done, processes its outputs. + Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, + EntryVector* outputs, NodeExecStats* stats); + + // After processing the outputs, propagates the outputs to their dsts. + void PropagateOutputs(const TaggedNode& tagged_node, + const EntryVector& outputs, TaggedNodeSeq* ready); + + // "node" just finishes. Takes ownership of "stats". Returns true if + // execution has completed. + bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, + NodeExecStats* stats, std::deque* inline_ready); + + // Call Process() on all nodes in 'inline_ready'. + void ProcessInline(const std::deque& inline_ready); + + // Schedule all the expensive nodes in 'ready', and put all the inexpensive + // nodes in 'ready' into 'inline_ready'. + void ScheduleReady(const TaggedNodeSeq& ready, + std::deque* inline_ready); + + // One thread of control finishes. + void Finish(); +}; + +ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) + : rendezvous_(args.rendezvous), + stats_collector_(args.stats_collector), + slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), + call_frame_(args.call_frame), + impl_(impl), + cancellation_manager_(args.cancellation_manager), + runner_(args.runner), + num_outstanding_ops_(0) { + // We start the entire execution in iteration 0 of the root frame + // so let us create the root frame and the state for iteration 0. + // Initialize the frame. + root_frame_ = new FrameState; + root_frame_->frame_name = "_root"; // assume to be unique + root_frame_->frame_id = 0; // must be 0 + root_frame_->num_pending_inputs = 0; + root_frame_->num_outstanding_iterations = 1; + root_frame_->max_parallel_iterations = 1; // enough for root frame + root_frame_->iterations.resize(root_frame_->max_parallel_iterations); + + VLOG(2) << "Create frame: " << root_frame_->frame_name; + + // Initialize the iteration. + IterationState* iter_state = new IterationState; + root_frame_->iterations[0] = iter_state; + iter_state->outstanding_ops = 0; + iter_state->outstanding_frame_count = 0; + iter_state->pending_count = new std::vector; + iter_state->dead_count = new std::vector(impl->graph_->num_node_ids()); + iter_state->input_tensors = new std::vector(impl_->total_tensors_); + + // Initialize the executor state. 
+ outstanding_frames_.insert({root_frame_->frame_name, root_frame_}); +} + +ExecutorState::~ExecutorState() { + for (auto name_frame : outstanding_frames_) { + delete name_frame.second; + } + + for (auto it : device_context_map_) { + it.second->Unref(); + } + + delete slice_reader_cache_; +} + +void ExecutorState::InitializePending(const Graph* graph, + std::vector* pending) { + pending->resize(graph->num_node_ids()); + for (const Node* n : graph->nodes()) { + const int id = n->id(); + const int num_in_edges = n->in_edges().size(); + if (IsMerge(n)) { + // merge waits all control inputs so we initialize the pending + // count to be the number of control edges. + int32 num_control_edges = 0; + for (const Edge* edge : n->in_edges()) { + if (edge->IsControlEdge()) { + num_control_edges++; + } + } + // Use bit 0 to indicate if there is a ready live data input. + (*pending)[id] = num_control_edges << 1; + } else { + (*pending)[id] = num_in_edges; + } + } +} + +void ExecutorState::RunAsync(Executor::DoneCallback done) { + const Graph* graph = impl_->graph_; + TaggedNodeSeq ready; + + { + // Initialize the executor state. We grab the mutex here just to + // keep the thread safety analysis happy. + mutex_lock l(mu_); + std::vector* pending = root_frame_->iterations[0]->pending_count; + InitializePending(graph, pending); + } + + // Ask the device to fill in the device context map. + Device* device = impl_->params_.device; + device->FillContextMap(graph, &device_context_map_); + + // Initialize the ready queue. + for (const Node* n : graph->nodes()) { + const int num_in_edges = n->in_edges().size(); + if (num_in_edges == 0) { + ready.push_back(TaggedNode{n, root_frame_, 0, false}); + } + } + if (ready.empty()) { + done(Status::OK()); + } else { + num_outstanding_ops_ = ready.size(); + root_frame_->iterations[0]->outstanding_ops = ready.size(); + done_cb_ = done; + // Schedule to run all the ready ops in thread pool. + ScheduleReady(ready, nullptr); + } +} + +namespace { + +// This function is provided for use by OpKernelContext when allocating +// the index'th output of node. It provides access to the +// AllocatorAttributes computed during initialization to determine in +// which memory region the tensor should be allocated. +AllocatorAttributes OutputAttributes(const DevAttrMap* attr_map, + const Node* node, + const OpKernel* op_kernel, int index) { + DCHECK_GE(index, 0); + + AllocatorAttributes attr; + int nid = node->id(); + const auto& iter = attr_map->find(std::make_pair(nid, index)); + if (iter != attr_map->end()) { + attr = iter->second; + VLOG(2) << "nondefault attr " << attr.value << " for node " << node->name() + << " output " << index; + } else { + VLOG(2) << "default attr for node " << node->name() << " output " << index; + } + + DCHECK_LT(index, op_kernel->output_memory_types().size()); + bool on_host = op_kernel->output_memory_types()[index] == HOST_MEMORY; + attr.set_on_host(on_host); + return attr; +} + +// Helpers to make a copy of 'p' and makes a copy of the input type +// vector and the device context vector. +// +// NOTE: We need to make a copy of p.input for asynchronous kernel +// because OpKernelContext methods like input_type(i) needs the param +// points to valid input type vector. It's not an issue for sync +// kernels because the type vector is kept on the stack. 
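A sketch of turning the asynchronous RunAsync interface above into a blocking call with a Notification. The executor is assumed to have been created elsewhere (executor.h declares a factory for local executors, not shown here), args.runner is whatever closure scheduler the caller normally supplies, and DoneCallback is assumed to be invocable with the final Status, as RunAsync above does.

    #include "tensorflow/core/common_runtime/executor.h"
    #include "tensorflow/core/lib/core/notification.h"
    #include "tensorflow/core/public/status.h"

    static tensorflow::Status RunAndWait(tensorflow::Executor* executor,
                                         const tensorflow::Executor::Args& args) {
      tensorflow::Notification done;
      tensorflow::Status run_status;
      executor->RunAsync(args,
                         [&run_status, &done](const tensorflow::Status& s) {
                           run_status = s;
                           done.Notify();
                         });
      done.WaitForNotification();  // blocks until the executor calls done()
      return run_status;
    }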
+OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) { + OpKernelContext::Params* ret = new OpKernelContext::Params; + *ret = p; + ret->inputs = new TensorValueVec(*p.inputs); + ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts); + ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs); + return ret; +} + +// Helpers to delete 'p' and copies made by CopyParams. +void DeleteParams(OpKernelContext::Params* p) { + delete p->inputs; + delete p->input_device_contexts; + delete p->input_alloc_attrs; + delete p; +} + +} // namespace + +void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { + const std::vector& nodes = impl_->nodes_; + TaggedNodeSeq ready; + std::deque inline_ready; + + // Parameters passed to OpKernel::Compute. + TensorValueVec inputs; + DeviceContextVec input_device_contexts; + AllocatorAttributeVec input_alloc_attrs; + + OpKernelContext::Params params; + Device* device = impl_->params_.device; + params.device = device; + // track allocations if and only if we are collecting statistics + params.track_allocations = (stats_collector_ != nullptr); + params.rendezvous = rendezvous_; + params.cancellation_manager = cancellation_manager_; + params.call_frame = call_frame_; + params.function_library = impl_->params_.function_library; + params.resource_manager = device->resource_manager(); + params.step_resource_manager = &step_resource_manager_; + params.slice_reader_cache = slice_reader_cache_; + params.inputs = &inputs; + params.input_device_contexts = &input_device_contexts; + params.input_alloc_attrs = &input_alloc_attrs; + + Status s; + NodeExecStats* stats = nullptr; + EntryVector outputs; + bool completed = false; + inline_ready.push_back(tagged_node); + while (!inline_ready.empty()) { + tagged_node = inline_ready.front(); + inline_ready.pop_front(); + const Node* node = tagged_node.node; + FrameState* input_frame = tagged_node.input_frame; + int64 input_iter = tagged_node.input_iter; + const int id = node->id(); + const NodeItem& item = nodes[id]; + + // Set the device_context for this node id, if it exists. + auto dc_it = device_context_map_.find(id); + if (dc_it != device_context_map_.end()) { + params.op_device_context = dc_it->second; + } + + if (stats_collector_) { + stats = new NodeExecStats; + stats->set_node_name(node->name()); + nodestats::SetScheduled(stats, scheduled_usec); + nodestats::SetAllStart(stats); + } + + VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def()); + + std::vector* input_tensors; + { + // Need the lock because the iterations vector could be resized by + // another thread. + mutex_lock l(mu_); + input_tensors = input_frame->GetIteration(input_iter)->input_tensors; + } + Entry* first_input = input_tensors->data() + item.input_start; + outputs.clear(); + outputs.resize(node->num_outputs()); + + // Only execute this node if it is not dead or it is a send/recv + // transfer node. For transfer nodes, we need to propagate the "dead" + // bit even when the node is dead. + AsyncOpKernel* async = nullptr; + if (!tagged_node.is_dead || IsTransferNode(node)) { + // Prepares inputs. + bool is_input_dead = false; + s = PrepareInputs(item, first_input, &inputs, &input_device_contexts, + &input_alloc_attrs, &is_input_dead); + if (!s.ok()) { + // Continue to process the nodes in 'inline_ready'. + completed = NodeDone(s, item.node, ready, stats, &inline_ready); + continue; + } + + // Set up compute params. 
+ OpKernel* op_kernel = item.kernel; + params.op_kernel = op_kernel; + params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter); + params.is_input_dead = is_input_dead; + params.output_alloc_attr = [this, node, op_kernel](int index) { + return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index); + }; + + async = op_kernel->AsAsync(); + if (async) { + // Asynchronous computes. + auto pcopy = CopyParams(params); + auto ctx = new OpKernelContext(*pcopy); + auto done = [this, tagged_node, item, first_input, ctx, stats, + pcopy]() { + VLOG(2) << this << " Async kernel done: " + << SummarizeNodeDef(item.node->def()); + if (stats_collector_) nodestats::SetOpEnd(stats); + EntryVector outputs; + Status s = ProcessOutputs(item, ctx, &outputs, stats); + if (stats_collector_) nodestats::SetMemory(stats, ctx); + // Clears inputs. + int num_inputs = tagged_node.node->num_inputs(); + for (int i = 0; i < num_inputs; ++i) { + (first_input + i)->val = *kEmptyTensor; + } + TaggedNodeSeq ready; + if (s.ok()) { + PropagateOutputs(tagged_node, outputs, &ready); + } + // Schedule to run all the ready ops in thread pool. + bool completed = NodeDone(s, item.node, ready, stats, nullptr); + delete ctx; + DeleteParams(pcopy); + if (completed) Finish(); + }; + if (stats_collector_) nodestats::SetOpStart(stats); + device->ComputeAsync(async, ctx, done); + } else { + // Synchronous computes. + OpKernelContext ctx(params); + if (stats_collector_) nodestats::SetOpStart(stats); + device->Compute(CHECK_NOTNULL(op_kernel), &ctx); + if (stats_collector_) nodestats::SetOpEnd(stats); + + // Processes outputs. + s = ProcessOutputs(item, &ctx, &outputs, stats); + if (stats_collector_) nodestats::SetMemory(stats, &ctx); + } + } + + if (!async) { + // Clears inputs. + int num_inputs = node->num_inputs(); + for (int i = 0; i < num_inputs; ++i) { + (first_input + i)->val = *kEmptyTensor; + } + // Propagates outputs. + if (s.ok()) { + PropagateOutputs(tagged_node, outputs, &ready); + } + if (stats_collector_) { + scheduled_usec = nodestats::NowInUsec(); + } + // Postprocess. + completed = NodeDone(s, item.node, ready, stats, &inline_ready); + } + } // while !inline_ready.empty() + + // This thread of computation is done if completed = true. + if (completed) Finish(); +} + +Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, + TensorValueVec* inputs, + DeviceContextVec* input_device_contexts, + AllocatorAttributeVec* input_alloc_attrs, + bool* is_input_dead) { + const Node* node = item.node; + + inputs->clear(); + inputs->resize(node->num_inputs()); + input_device_contexts->clear(); + input_device_contexts->resize(node->num_inputs()); + input_alloc_attrs->clear(); + input_alloc_attrs->resize(node->num_inputs()); + + *is_input_dead = false; + + bool is_merge = IsMerge(node); + for (int i = 0; i < node->num_inputs(); ++i) { + const bool expect_ref = IsRefType(node->input_type(i)); + Entry* entry = first_input + i; + (*input_device_contexts)[i] = entry->device_context; + (*input_alloc_attrs)[i] = entry->alloc_attr; + + // i-th input. + TensorValue* inp = &(*inputs)[i]; + + // Only merge and transfer nodes can have no-value inputs. 
+ if (!entry->has_value) { + if (!is_merge) { + DCHECK(IsTransferNode(node)); + inp->tensor = &entry->val; + *is_input_dead = true; + } + continue; + } + if (entry->ref == nullptr) { + if (expect_ref) { + return AttachDef( + errors::InvalidArgument(i, "-th input expects a ref type"), + item.kernel->def()); + } + inp->tensor = &entry->val; + } else { + if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) { + return AttachDef( + errors::FailedPrecondition("Attempting to use uninitialized value ", + item.kernel->def().input(i)), + item.kernel->def()); + } + if (expect_ref) { + inp->mutex_if_ref = entry->ref_mu; + inp->tensor = entry->ref; + } else { + // Automatically deref the tensor ref when the op expects a + // tensor but is given a ref to a tensor. Need to deref it + // under the mutex. + { + mutex_lock l(*(entry->ref_mu)); + entry->val = *entry->ref; + } + inp->tensor = &entry->val; + } + } + } + return Status::OK(); +} + +Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, + EntryVector* outputs, + NodeExecStats* stats) { + const Node* node = item.node; + outputs->clear(); + outputs->resize(node->num_outputs()); + + Status s = ctx->status(); + if (!s.ok()) { + s = AttachDef(s, item.kernel->def()); + LOG(WARNING) << this << " Compute status: " << s; + return s; + } + + // Get the device_context for this node id, if it exists. + DeviceContext* device_context = nullptr; + auto dc_it = device_context_map_.find(node->id()); + if (dc_it != device_context_map_.end()) { + device_context = dc_it->second; + } + + for (int i = 0; i < node->num_outputs(); ++i) { + TensorValue val = ctx->release_output(i); + // Only Switch and Recv nodes can generate new dead outputs + if (*ctx->is_output_dead() || val.tensor == nullptr) { + DCHECK(IsSwitch(node) || IsRecv(node)); + } else { + Entry* out = &((*outputs)[i]); + out->has_value = true; + + // Set the device context of the output entry. + out->device_context = device_context; + + // Set the allocator attributes of the output entry. + out->alloc_attr = ctx->output_alloc_attr(i); + + // Sanity check of output tensor types. + DataType dtype = val->dtype(); + if (val.is_ref()) dtype = MakeRefType(dtype); + if (dtype == node->output_type(i)) { + if (val.is_ref()) { + out->ref = val.tensor; + out->ref_mu = val.mutex_if_ref; + } else { + out->val = *val.tensor; + } + if (stats_collector_ && val.tensor->IsInitialized()) { + nodestats::SetOutput(stats, i, ctx->output_allocation_type(i), + val.tensor); + } + } else { + s.Update(errors::Internal("Output ", i, " of type ", + DataTypeString(dtype), + " does not match declared output type ", + DataTypeString(node->output_type(i)), + " for node ", SummarizeNodeDef(node->def()))); + } + } + if (!val.is_ref()) { + // If OpKernelContext returns outputs via pass-by-value, we + // don't need this trouble. + delete val.tensor; + } + } + return s; +} + +void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, + const EntryVector& outputs, + TaggedNodeSeq* ready) { + FrameState* input_frame = tagged_node.input_frame; + int64 input_iter = tagged_node.input_iter; + + // Propagates outputs along out edges, and puts newly ready nodes + // into the ready queue. + ready->clear(); + + { + FrameState* output_frame = input_frame; + int64 output_iter = input_iter; + + mutex_lock l(mu_); + // Sets the output_frame and output_iter of node. 
+ bool maybe_completed = SetOutputFrameIter( + tagged_node, outputs, &output_frame, &output_iter, ready); + if (output_frame != nullptr) { + // Continue to process the out nodes: + ActivateNode(tagged_node.node, tagged_node.is_dead, output_frame, + output_iter, outputs, ready); + } + + // At this point, this node is completely done. + input_frame->GetIteration(input_iter)->outstanding_ops--; + CleanupFramesIterations(input_frame, input_iter, ready); + + // The execution of a node such as Enter may cause the completion of + // output_frame:output_iter, so perform cleanup if output_frame:output_iter + // is indeed completed. + if (maybe_completed) { + CleanupFramesIterations(output_frame, output_iter, ready); + } + } +} + +void ExecutorState::ActivateNode(const Node* node, const bool is_dead, + FrameState* output_frame, int64 output_iter, + const EntryVector& outputs, + TaggedNodeSeq* ready) { + const std::vector& nodes = impl_->nodes_; + IterationState* output_iter_state = output_frame->GetIteration(output_iter); + std::vector* pending = output_iter_state->pending_count; + std::vector* dead_count = output_iter_state->dead_count; + for (const Edge* e : node->out_edges()) { + const Node* dst_node = e->dst(); + const int dst_id = dst_node->id(); + const int src_slot = e->src_output(); + + bool dst_dead = false; + bool dst_ready = false; + bool dst_need_input = !e->IsControlEdge(); + if (IsMerge(dst_node)) { + // A merge node is ready if a) all control edges are enabled and a + // live data input becomes available, or b) all control edges are + // enabled and all data inputs are dead. + if (e->IsControlEdge()) { + (*pending)[dst_id] -= 2; + int count = (*pending)[dst_id]; + dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs()); + dst_ready = (count == 1) || ((count == 0) && dst_dead); + } else { + if (outputs[src_slot].has_value) { + // This is a live data input. + int count = (*pending)[dst_id]; + (*pending)[dst_id] |= 0x1; + dst_ready = (count == 0); + } else { + // This is a dead data input. + ++(*dead_count)[dst_id]; + dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs()); + dst_ready = ((*pending)[dst_id] == 0) && dst_dead; + } + // This input for dst is not needed if !dst_ready. We suppress the + // propagation to make the thread safety analysis happy. + dst_need_input = dst_ready; + } + } else { + // A non-merge node is ready if all its inputs are ready. We wait + // for all inputs to come in even if we know the node is dead. This + // ensures that all input tensors get cleaned up. + if (is_dead || (!e->IsControlEdge() && !outputs[src_slot].has_value)) { + ++(*dead_count)[dst_id]; + } + dst_dead = (*dead_count)[dst_id] > 0; + dst_ready = (--(*pending)[dst_id] == 0); + } + + if (dst_need_input) { + const NodeItem& dst_item = nodes[dst_id]; + const int dst_slot = e->dst_input(); + std::vector* input_tensors = output_iter_state->input_tensors; + int dst_loc = dst_item.input_start + dst_slot; + (*input_tensors)[dst_loc] = outputs[src_slot]; + } + + // Add dst to the ready queue if it's ready + if (dst_ready) { + dst_dead = dst_dead && !IsControlTrigger(dst_node); + ready->push_back( + TaggedNode(dst_node, output_frame, output_iter, dst_dead)); + output_iter_state->outstanding_ops++; + } + } +} + +void ExecutorState::ActivateNexts(FrameState* frame, int64 iter, + TaggedNodeSeq* ready) { + // Propagate the deferred NextIteration nodes to the new iteration. 
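+  // A NextIteration value reaches next_iter_roots only when it had to be
+  // deferred: SetOutputFrameIter() stores {node, outputs[0]} there once the
+  // frame is already running max_parallel_iterations iterations, and
+  // IncrementIteration() replays the deferred entries (via this function)
+  // as soon as a new iteration may start. Roughly:
+  //
+  //   // In SetOutputFrameIter(), at the parallelism limit:
+  //   input_frame->next_iter_roots.push_back({node, outputs[0]});
+  //   // Later, when an iteration completes:
+  //   IncrementIteration(frame, ready);  // calls ActivateNexts(frame, ...)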
+ for (auto& node_entry : frame->next_iter_roots) { + const Node* node = node_entry.first; + const Entry& entry = node_entry.second; + const bool is_dead = !entry.has_value; + ActivateNode(node, is_dead, frame, iter, {entry}, ready); + } + frame->next_iter_roots.clear(); +} + +void ExecutorState::ActivateLoopInvs(FrameState* frame, int64 iter, + TaggedNodeSeq* ready) { + // Propagate loop invariants to the new iteration. + for (auto& node_entry : frame->inv_values) { + const Node* node = node_entry.first; + const Entry& entry = node_entry.second; + const bool is_dead = !entry.has_value; + ActivateNode(node, is_dead, frame, iter, {entry}, ready); + } +} + +void ExecutorState::AddLoopInv(FrameState* frame, const Node* node, + const Entry& entry, TaggedNodeSeq* ready) { + // Store this value. + frame->inv_values.push_back({node, entry}); + + // Make this value available to all iterations. + bool is_dead = !entry.has_value; + for (int i = 1; i <= frame->iteration_count; ++i) { + ActivateNode(node, is_dead, frame, i, {entry}, ready); + } +} + +bool ExecutorState::NodeDone(const Status& s, const Node* node, + const TaggedNodeSeq& ready, NodeExecStats* stats, + std::deque* inline_ready) { + if (stats_collector_) { + nodestats::SetAllEnd(stats); + if (!SetTimelineLabel(node, stats)) { + // Only record non-transfer nodes. + stats_collector_->Save(impl_->params_.device->name(), stats); + } else { + delete stats; + } + } + + Rendezvous* captured_rendezvous = nullptr; // Will be set on error. + if (!s.ok()) { + // Some error happened. This thread of computation is done. + mutex_lock l(mu_); + if (status_.ok()) { + captured_rendezvous = rendezvous_; + if (captured_rendezvous) captured_rendezvous->Ref(); + status_ = s; + } + } + if (captured_rendezvous) { + // If we captured the rendezvous_ pointer, we are in an error condition. + // Use captured_rendezvous, in case "this" is deleted by another thread. + TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); + captured_rendezvous->StartAbort(s); + captured_rendezvous->Unref(); + } + + bool completed = false; + int ready_size = ready.size(); + if (ready_size == 0 || !s.ok()) { + completed = (num_outstanding_ops_.fetch_sub(1) == 1); + } else if (ready_size > 1) { + num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed); + } + + // Schedule the ready nodes in 'ready'. + if (s.ok()) { + ScheduleReady(ready, inline_ready); + } + return completed; +} + +void ExecutorState::ProcessInline(const std::deque& inline_ready) { + if (inline_ready.empty()) return; + int64 scheduled_usec = 0; + if (stats_collector_) { + scheduled_usec = nodestats::NowInUsec(); + } + for (auto& tagged_node : inline_ready) { + Process(tagged_node, scheduled_usec); + } +} + +void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready, + std::deque* inline_ready) { + if (ready.empty()) return; + + int64 scheduled_usec = 0; + if (stats_collector_) { + scheduled_usec = nodestats::NowInUsec(); + } + if (inline_ready == nullptr) { + // Schedule to run all the ready ops in thread pool. + for (auto& tagged_node : ready) { + runner_(std::bind(&ME::Process, this, tagged_node, scheduled_usec)); + } + return; + } + const std::vector& nodes = impl_->nodes_; + const TaggedNode* curr_expensive_node = nullptr; + for (auto& tagged_node : ready) { + const NodeItem& item = nodes[tagged_node.node->id()]; + if (tagged_node.is_dead || !item.kernel->IsExpensive()) { + // Inline this inexpensive node. 
+ inline_ready->push_back(tagged_node); + } else { + if (curr_expensive_node) { + // Dispatch to another thread since there is plenty of work to + // do for this thread. + runner_(std::bind(&ME::Process, this, *curr_expensive_node, + scheduled_usec)); + } + curr_expensive_node = &tagged_node; + } + } + if (curr_expensive_node) { + if (inline_ready->empty()) { + // Tail recursion optimization + inline_ready->push_back(*curr_expensive_node); + } else { + // There are inline nodes to run already. We dispatch this expensive + // node to other thread. + runner_( + std::bind(&ME::Process, this, *curr_expensive_node, scheduled_usec)); + } + } +} + +void ExecutorState::Finish() { + mu_.lock(); + auto status = status_; + auto done_cb = done_cb_; + auto runner = runner_; + mu_.unlock(); + delete this; + CHECK(done_cb != nullptr); + runner([done_cb, status]() { done_cb(status); }); +} + +bool ExecutorState::IsFrameDone(FrameState* frame) { + return (frame->num_pending_inputs == 0 && + frame->num_outstanding_iterations == 0); +} + +bool ExecutorState::IsIterationDone(FrameState* frame, int64 iter) { + IterationState* iter_state = frame->GetIteration(iter); + if (iter_state->outstanding_ops == 0 && + iter_state->outstanding_frame_count == 0) { + if (iter == 0) { + // The enclosing frame has no pending input. + return frame->num_pending_inputs == 0; + } else { + // The preceding iteration is deleted (and therefore done). + return (frame->GetIteration(iter - 1) == nullptr); + } + } + return false; +} + +void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, + const Node* node, + FrameState** child) { + // Get the child frame name. + string enter_name; + Status s = GetNodeAttr(node->def(), "frame_name", &enter_name); + CHECK(s.ok()) << s; + const string child_name = MakeFrameName(frame, iter, enter_name); + + auto it = outstanding_frames_.find(child_name); + if (it != outstanding_frames_.end()) { + *child = it->second; + } else { + // Need to create a new frame instance. + VLOG(2) << "Create frame: " << child_name; + + FrameState* temp = new FrameState; + temp->frame_name = child_name; + temp->frame_id = Hash64(child_name); + temp->parent_frame = frame; + temp->parent_iter = iter; + s = GetNodeAttr(node->def(), "parallel_iterations", + &temp->max_parallel_iterations); + CHECK(s.ok()) << s; + // 'iterations' is a fixed-length circular buffer. 
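+    // With (max_parallel_iterations + 1) slots, iteration "i" presumably
+    // maps to iterations[i % (max_parallel_iterations + 1)]; the extra slot
+    // leaves room for the oldest live iteration while the next one is being
+    // created (GetIteration/SetIteration are defined earlier in this file).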
+    temp->iterations.resize(temp->max_parallel_iterations + 1);
+    IterationState* iter_state = new IterationState;
+    temp->iterations[0] = iter_state;
+
+    iter_state->outstanding_ops = 0;
+    iter_state->outstanding_frame_count = 0;
+    iter_state->pending_count = new std::vector<int>;
+    InitializePending(impl_->graph_, iter_state->pending_count);
+    iter_state->dead_count =
+        new std::vector<int>(impl_->graph_->num_node_ids());
+    iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+    auto frame_pending = impl_->frame_input_count_.find(enter_name);
+    DCHECK(frame_pending != impl_->frame_input_count_.end());
+    temp->num_pending_inputs = frame_pending->second;
+    temp->num_outstanding_iterations = 1;
+    *child = temp;
+
+    frame->GetIteration(iter)->outstanding_frame_count++;
+    outstanding_frames_[child_name] = temp;
+  }
+}
+
+void ExecutorState::IncrementIteration(FrameState* frame,
+                                       TaggedNodeSeq* ready) {
+  frame->iteration_count++;
+  int64 next_iter = frame->iteration_count;
+
+  VLOG(2) << "Create iteration: [" << frame->frame_name << ", " << next_iter
+          << "]";
+
+  IterationState* iter_state = new IterationState;
+  frame->SetIteration(next_iter, iter_state);
+  frame->num_outstanding_iterations++;
+  frame->dead_exits.clear();
+
+  iter_state->outstanding_ops = 0;
+  iter_state->outstanding_frame_count = 0;
+  iter_state->pending_count = new std::vector<int>;
+  InitializePending(impl_->graph_, iter_state->pending_count);
+  iter_state->dead_count = new std::vector<int>(impl_->graph_->num_node_ids());
+  iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+  // Activate the successors of the deferred roots in the new iteration.
+  ActivateNexts(frame, next_iter, ready);
+
+  // Activate the loop invariants in the new iteration.
+  ActivateLoopInvs(frame, next_iter, ready);
+}
+
+bool ExecutorState::SetOutputFrameIter(const TaggedNode& tagged_node,
+                                       const EntryVector& outputs,
+                                       FrameState** output_frame,
+                                       int64* output_iter,
+                                       TaggedNodeSeq* ready) {
+  const Node* node = tagged_node.node;
+  FrameState* input_frame = tagged_node.input_frame;
+  int64 input_iter = tagged_node.input_iter;
+  bool is_dead = tagged_node.is_dead;
+  bool is_enter = IsEnter(node);
+
+  if (is_enter) {
+    FindOrCreateChildFrame(input_frame, input_iter, node, output_frame);
+    // Propagate if this is a loop invariant.
+    bool is_constant;
+    Status s = GetNodeAttr(node->def(), "is_constant", &is_constant);
+    CHECK(s.ok()) << s;
+    if (is_constant) {
+      AddLoopInv(*output_frame, node, outputs[0], ready);
+    }
+    --(*output_frame)->num_pending_inputs;
+    *output_iter = 0;
+  } else if (IsExit(node)) {
+    if (is_dead) {
+      // Stop and remember this node if it is a dead exit.
+      if (input_iter == input_frame->iteration_count) {
+        input_frame->dead_exits.push_back(node);
+      }
+      *output_frame = nullptr;
+    } else {
+      *output_frame = input_frame->parent_frame;
+      *output_iter = input_frame->parent_iter;
+    }
+  } else if (IsNextIteration(node)) {
+    if (is_dead) {
+      // Stop the deadness propagation
+      *output_frame = nullptr;
+    } else {
+      if (input_iter == input_frame->iteration_count &&
+          input_frame->num_outstanding_iterations ==
+              input_frame->max_parallel_iterations) {
+        // Reached the maximum for parallel iterations.
+        input_frame->next_iter_roots.push_back({node, outputs[0]});
+        *output_frame = nullptr;
+      } else {
+        // If this is a new iteration, start it.
+ if (input_iter == input_frame->iteration_count) { + IncrementIteration(input_frame, ready); + } + *output_iter = input_iter + 1; + } + } + } + return is_enter; +} + +void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter, + TaggedNodeSeq* ready) { + int64 curr_iter = iter; + while (curr_iter <= frame->iteration_count && + IsIterationDone(frame, curr_iter)) { + // Delete the iteration curr_iter + VLOG(2) << "Delete iteration [" << frame->frame_name << ", " << curr_iter + << "]."; + + delete frame->GetIteration(curr_iter); + frame->SetIteration(curr_iter, nullptr); + --frame->num_outstanding_iterations; + ++curr_iter; + + // If there is a deferred iteration, start it. + if (frame->next_iter_roots.size() > 0) { + IncrementIteration(frame, ready); + } + } + + if (IsFrameDone(frame)) { + FrameState* parent_frame = frame->parent_frame; + int64 parent_iter = frame->parent_iter; + + // Propagate all the dead exits to the parent frame. + for (const Node* node : frame->dead_exits) { + auto parent_iter_state = parent_frame->GetIteration(parent_iter); + std::vector* pending = parent_iter_state->pending_count; + std::vector* dead_count = parent_iter_state->dead_count; + for (const Edge* e : node->out_edges()) { + const Node* dst_node = e->dst(); + const int dst_id = dst_node->id(); + + bool dst_dead = true; + bool dst_ready = false; + // We know this is a dead input to dst + if (IsMerge(dst_node)) { + if (e->IsControlEdge()) { + (*pending)[dst_id] -= 2; + int count = (*pending)[dst_id]; + dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs()); + dst_ready = (count == 1) || ((count == 0) && dst_dead); + } else { + ++(*dead_count)[dst_id]; + dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs()); + dst_ready = ((*pending)[dst_id] == 0) && dst_dead; + } + } else { + ++(*dead_count)[dst_id]; + dst_ready = (--(*pending)[dst_id] == 0); + } + if (dst_ready) { + ready->push_back( + TaggedNode(dst_node, parent_frame, parent_iter, dst_dead)); + parent_iter_state->outstanding_ops++; + } + } + } + + // Delete the frame + const string& frame_name = frame->frame_name; + VLOG(2) << "Delete frame " << frame_name; + outstanding_frames_.erase(frame_name); + delete frame; + + // Cleanup recursively + if (parent_frame != nullptr) { + parent_frame->GetIteration(parent_iter)->outstanding_frame_count--; + CleanupFramesIterations(parent_frame, parent_iter, ready); + } + } +} + +// When ExecutorImpl graph has no control flow nodes, +// SimpleExecutorState is used instead of ExecutorState. It maintains +// fewer internal state and is convenient for experimenting with async +// op kernels. +class SimpleExecutorState { + public: + SimpleExecutorState(const Executor::Args& args, ExecutorImpl* impl); + ~SimpleExecutorState() { + for (auto it : device_context_map_) { + it.second->Unref(); + } + delete slice_reader_cache_; + } + void RunAsync(Executor::DoneCallback done); + + private: + typedef SimpleExecutorState ME; + + // Not owned. + Rendezvous* rendezvous_; + StepStatsCollector* stats_collector_; + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; + FunctionCallFrame* call_frame_; + const ExecutorImpl* impl_; + CancellationManager* cancellation_manager_; + Executor::Args::Runner runner_; + + // Owned. + + // i-th node's j-th input is in tensors_[impl_->nodes[i].input_start + // + j]. The output is either a tensor pointer (pass-by-reference) + // or a tensor (pass-by-value). + // + // NOTE: Not protected by mu_ because tensors_ is resized once. 
Each
+  // element of tensors_ is written once by the source node of an edge
+  // and is cleared by the destination of the same edge. The latter
+  // node is never run concurrently with the former node.
+  struct Entry {
+    Tensor val = *kEmptyTensor;  // A tensor value.
+    Tensor* ref = nullptr;       // A tensor reference.
+    mutex* ref_mu = nullptr;     // mutex for *ref if ref is not nullptr.
+
+    // Every entry carries an optional DeviceContext containing
+    // Device-specific information about how the Tensor was produced.
+    DeviceContext* device_context = nullptr;
+
+    // The attributes of the allocator that creates the tensor.
+    AllocatorAttributes alloc_attr;
+  };
+
+  // Contains a map from node id to the DeviceContext object that was
+  // assigned by the device at the beginning of a step.
+  DeviceContextMap device_context_map_;
+
+  std::vector<Entry> input_tensors_;
+
+  // Step-local resource manager.
+  ResourceMgr step_resource_manager_;
+
+  // Invoked when the execution finishes.
+  Executor::DoneCallback done_cb_;
+
+  // How many active threads of computation are being used. Same as
+  // the number of pending Process() functions.
+  std::atomic_int_fast32_t num_active_;
+
+  mutex mu_;
+  Status status_ GUARDED_BY(mu_);
+
+  // i-th kernel is still waiting for pending[i] inputs.
+  class CountDown {
+   public:
+    CountDown() : v_(0) {}
+    void Set(int32 v) { v_.store(v); }
+    bool Dec() {
+      return v_.load(std::memory_order_acquire) == 1 || v_.fetch_sub(1) == 1;
+    }
+
+   private:
+    std::atomic_int_fast32_t v_;
+  };
+  std::vector<CountDown> pending_;
+
+  // Process Node identified by "id" in current thread. "scheduled_usec"
+  // indicates when the node becomes ready and gets scheduled.
+  void Process(int id, int64 scheduled_usec);
+
+  // Before invoking item->kernel, fills in its "inputs".
+  Status PrepareInputs(const NodeItem& item, TensorValueVec* inputs,
+                       DeviceContextVec* input_device_contexts);
+
+  // After item->kernel computation is done, processes its outputs
+  // and returns nodes that become "ready".
+  typedef gtl::InlinedVector<int, 8> ReadyNodeIds;
+  Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+                        ReadyNodeIds* ready, NodeExecStats* stats);
+
+  // "node" just finishes. Takes ownership of "stats". Returns true if
+  // execution has completed.
+  bool NodeDone(const Status& s, const Node* node, const ReadyNodeIds& ready,
+                NodeExecStats* stats, std::deque<int>* inline_ready);
+
+  // Call Process() on all nodes in 'inline_ready'.
+  void ProcessInline(const std::deque<int>& inline_ready);
+
+  // Schedule all the expensive nodes in 'ready', and put all the inexpensive
+  // nodes in 'ready' into 'inline_ready'.
+  void ScheduleReady(const ReadyNodeIds& ready, std::deque<int>* inline_ready);
+
+  // One thread of control finishes.
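+  // Finish() copies status_, done_cb_ and runner_ under mu_, then performs
+  // "delete this" BEFORE invoking the callback through the saved runner
+  // (see the definition below), so nothing after that point may touch
+  // member state:
+  //
+  //   mu_.lock();
+  //   auto ret = status_;
+  //   auto done_cb = done_cb_;
+  //   auto runner = runner_;
+  //   mu_.unlock();
+  //   delete this;
+  //   runner([done_cb, ret]() { done_cb(ret); });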
+ void Finish(); + + TF_DISALLOW_COPY_AND_ASSIGN(SimpleExecutorState); +}; + +SimpleExecutorState::SimpleExecutorState(const Executor::Args& args, + ExecutorImpl* impl) + : rendezvous_(args.rendezvous), + stats_collector_(args.stats_collector), + slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), + call_frame_(args.call_frame), + impl_(impl), + cancellation_manager_(args.cancellation_manager), + runner_(args.runner), + num_active_(0), + pending_(impl_->nodes_.size()) {} + +void SimpleExecutorState::ProcessInline(const std::deque& inline_ready) { + if (inline_ready.empty()) return; + int64 scheduled_usec = 0; + if (stats_collector_) { + scheduled_usec = nodestats::NowInUsec(); + } + for (int id : inline_ready) { + Process(id, scheduled_usec); + } +} + +void SimpleExecutorState::ScheduleReady(const ReadyNodeIds& ready, + std::deque* inline_ready) { + if (ready.empty()) return; + + int64 scheduled_usec = 0; + if (stats_collector_) { + scheduled_usec = nodestats::NowInUsec(); + } + if (inline_ready == nullptr) { + // Schedule to run all the ready ops in thread pool. + for (auto id : ready) { + runner_(std::bind(&ME::Process, this, id, scheduled_usec)); + } + return; + } + const std::vector& nodes = impl_->nodes_; + int curr_expensive_node = -1; + for (auto id : ready) { + if (!nodes[id].kernel->IsExpensive()) { + // Inline this inexpensive node. + inline_ready->push_back(id); + } else { + if (curr_expensive_node != -1) { + // Dispatch to another thread since there is plenty of work to + // do for this thread. + runner_( + std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec)); + } + curr_expensive_node = id; + } + } + if (curr_expensive_node != -1) { + if (inline_ready->empty()) { + // Tail recursion optimization + inline_ready->push_back(curr_expensive_node); + } else { + // There are inline nodes to run already. We dispatch this expensive + // node to other thread. + runner_( + std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec)); + } + } +} + +void SimpleExecutorState::RunAsync(Executor::DoneCallback done) { + const Graph* graph = impl_->graph_; + ReadyNodeIds ready; + + // Ask the device to fill in the device context map. + Device* device = impl_->params_.device; + device->FillContextMap(graph, &device_context_map_); + + for (const Node* n : graph->nodes()) { + const int id = n->id(); + const int num_in_edges = n->in_edges().size(); + pending_[id].Set(num_in_edges); + if (num_in_edges == 0) { + ready.push_back(id); + } + } + if (ready.empty()) { + done(Status::OK()); + } else { + num_active_ = ready.size(); + done_cb_ = done; + input_tensors_.resize(impl_->total_tensors_); + // Schedule to run all the ready ops in thread pool. + ScheduleReady(ready, nullptr); + } +} + +Status SimpleExecutorState::PrepareInputs( + const NodeItem& item, TensorValueVec* inputs, + DeviceContextVec* input_device_contexts) { + const Node* node = item.node; + + inputs->clear(); + inputs->resize(node->num_inputs()); + input_device_contexts->clear(); + input_device_contexts->resize(node->num_inputs()); + + for (int i = 0; i < node->num_inputs(); ++i) { + const bool expect_ref = IsRefType(node->input_type(i)); + Entry* entry = input_tensors_.data() + item.input_start + i; + (*input_device_contexts)[i] = entry->device_context; + + // i-th input. 
+ TensorValue* inp = &(*inputs)[i]; + + if (entry->ref == nullptr) { + if (expect_ref) { + return AttachDef( + errors::InvalidArgument(i, "-th input expects a ref type"), + item.kernel->def()); + } + inp->tensor = &entry->val; + } else { + if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) { + return AttachDef( + errors::FailedPrecondition("Attempting to use uninitialized value ", + item.kernel->def().input(i)), + item.kernel->def()); + } + if (expect_ref) { + inp->mutex_if_ref = entry->ref_mu; + inp->tensor = entry->ref; + } else { + // Automatically deref the tensor ref when the op expects a + // tensor but is given a ref to a tensor. Need to deref it + // under the mutex. + { + mutex_lock l(*(entry->ref_mu)); + entry->val = *entry->ref; + } + inp->tensor = &entry->val; + } + } + } + return Status::OK(); +} + +void SimpleExecutorState::Process(int id, int64 scheduled_usec) { + const std::vector& nodes = impl_->nodes_; + ReadyNodeIds ready; + std::deque inline_ready; + + // Parameters passed to OpKernel::Compute. + TensorValueVec inputs; + DeviceContextVec input_device_contexts; + + OpKernelContext::Params params; + Device* device = impl_->params_.device; + params.device = device; + // track allocations if and only if we are collecting statistics + params.track_allocations = (stats_collector_ != nullptr); + params.rendezvous = rendezvous_; + params.cancellation_manager = cancellation_manager_; + params.call_frame = call_frame_; + params.function_library = impl_->params_.function_library; + params.resource_manager = device->resource_manager(); + params.step_resource_manager = &step_resource_manager_; + params.slice_reader_cache = slice_reader_cache_; + params.inputs = &inputs; + params.input_device_contexts = &input_device_contexts; + params.frame_iter = FrameAndIter(0, 0); + + Status s; + NodeExecStats* stats = nullptr; + bool completed = false; + inline_ready.push_back(id); + while (!inline_ready.empty()) { + id = inline_ready.front(); + inline_ready.pop_front(); + const NodeItem& item = nodes[id]; + const Node* node = item.node; + + // Set the device_context for this node id, if it exists. + auto dc_it = device_context_map_.find(id); + if (dc_it != device_context_map_.end()) { + params.op_device_context = dc_it->second; + } + + if (stats_collector_) { + stats = new NodeExecStats; + stats->set_node_name(node->name()); + nodestats::SetScheduled(stats, scheduled_usec); + nodestats::SetAllStart(stats); + } + + VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def()); + + // Prepares inputs. + s = PrepareInputs(item, &inputs, &input_device_contexts); + if (!s.ok()) { + // Continue to process the nodes in 'inline_ready'. + completed = NodeDone(s, item.node, ready, stats, &inline_ready); + continue; + } + + OpKernel* op_kernel = item.kernel; + params.op_kernel = op_kernel; + params.output_alloc_attr = [this, node, op_kernel](int index) { + return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index); + }; + + // Asynchronous computes. + AsyncOpKernel* async = op_kernel->AsAsync(); + if (async) { + auto pcopy = CopyParams(params); + auto ctx = new OpKernelContext(*pcopy); + auto done = [this, item, ctx, stats, pcopy]() { + VLOG(2) << this + << " Async kernel done: " << SummarizeNodeDef(item.node->def()); + if (stats_collector_) nodestats::SetOpEnd(stats); + ReadyNodeIds ready; + Status s = ProcessOutputs(item, ctx, &ready, stats); + if (stats_collector_) nodestats::SetMemory(stats, ctx); + // Schedule to run all the ready ops in thread pool. 
+ bool completed = NodeDone(s, item.node, ready, stats, nullptr); + delete ctx; + DeleteParams(pcopy); + if (completed) Finish(); + }; + if (stats_collector_) nodestats::SetOpStart(stats); + device->ComputeAsync(async, ctx, done); + } else { + // Synchronous computes. + OpKernelContext ctx(params); + if (stats_collector_) nodestats::SetOpStart(stats); + device->Compute(CHECK_NOTNULL(op_kernel), &ctx); + if (stats_collector_) nodestats::SetOpEnd(stats); + + s = ProcessOutputs(item, &ctx, &ready, stats); + if (stats_collector_) nodestats::SetMemory(stats, &ctx); + if (stats_collector_) { + scheduled_usec = nodestats::NowInUsec(); + } + completed = NodeDone(s, node, ready, stats, &inline_ready); + } + } // while !inline_ready.empty() + + // This thread of computation is done if completed = true. + if (completed) Finish(); +} + +bool SimpleExecutorState::NodeDone(const Status& s, const Node* node, + const ReadyNodeIds& ready, + NodeExecStats* stats, + std::deque* inline_ready) { + if (stats_collector_) { + nodestats::SetAllEnd(stats); + if (!SetTimelineLabel(node, stats)) { + // Only record non-transfer nodes. + stats_collector_->Save(impl_->params_.device->name(), stats); + } else { + delete stats; + } + } + + Rendezvous* captured_rendezvous = nullptr; // Will be set on error. + if (!s.ok()) { + // Some error happened. This thread of computation is done. + mutex_lock l(mu_); + if (status_.ok()) { + captured_rendezvous = rendezvous_; + if (captured_rendezvous) captured_rendezvous->Ref(); + status_ = s; + } + } + if (captured_rendezvous) { + // If we captured the rendezvous_ pointer, we are in an error condition. + // Use captured_rendezvous, in case "this" is deleted by another thread. + TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); + captured_rendezvous->StartAbort(s); + captured_rendezvous->Unref(); + } + + bool completed = false; + int ready_size = ready.size(); + if (ready_size == 0 || !s.ok()) { + completed = (num_active_.fetch_sub(1) == 1); + } else if (ready_size > 1) { + num_active_.fetch_add(ready_size - 1, std::memory_order_relaxed); + } + + // Schedule the ready nodes in 'ready'. + if (s.ok()) { + ScheduleReady(ready, inline_ready); + } + return completed; +} + +void SimpleExecutorState::Finish() { + mu_.lock(); + auto ret = status_; + auto done_cb = done_cb_; + auto runner = runner_; + mu_.unlock(); + delete this; + CHECK(done_cb != nullptr); + runner([done_cb, ret]() { done_cb(ret); }); +} + +Status SimpleExecutorState::ProcessOutputs(const NodeItem& item, + OpKernelContext* ctx, + ReadyNodeIds* ready, + NodeExecStats* stats) { + Status s = ctx->status(); + if (!s.ok()) { + s = AttachDef(s, item.kernel->def()); + LOG(WARNING) << this << " Compute status: " << s; + return s; + } + + // Processes outputs. + gtl::InlinedVector outputs; + const Node* node = item.node; + outputs.resize(node->num_outputs()); + + // Get the device_context for this node id, if it exists. + DeviceContext* device_context = nullptr; + auto dc_it = device_context_map_.find(node->id()); + if (dc_it != device_context_map_.end()) { + device_context = dc_it->second; + } + + for (int i = 0; i < node->num_outputs(); ++i) { + TensorValue val = ctx->release_output(i); + // Sanity check of output tensor types. 
+ DataType dtype = val->dtype(); + if (val.is_ref()) dtype = MakeRefType(dtype); + if (dtype == node->output_type(i)) { + Entry* out = &(outputs[i]); + if (val.is_ref()) { + out->ref = val.tensor; + out->ref_mu = val.mutex_if_ref; + } else { + out->val = *val.tensor; + } + + // Set the device context of the output entry. + out->device_context = device_context; + + // Set the allocator attributes of the output entry. + out->alloc_attr = ctx->output_alloc_attr(i); + + if (stats_collector_ && val.tensor->IsInitialized()) { + nodestats::SetOutput(stats, i, ctx->output_allocation_type(i), + val.tensor); + } + } else { + s.Update( + errors::Internal("Output ", i, " of type ", DataTypeString(dtype), + " does not match declared output type ", + DataTypeString(node->output_type(i)), + " for operation ", SummarizeNodeDef(node->def()))); + } + if (!val.is_ref()) { + // If OpKernelContext returns outputs via pass-by-value, we + // don't need this trouble. + delete val.tensor; + } + } + if (!s.ok()) return s; + + // Clears inputs. + for (int i = 0; i < node->num_inputs(); ++i) { + input_tensors_[item.input_start + i].val = *kEmptyTensor; + } + + // Propagates outputs along out edges. + ready->clear(); + const std::vector& nodes = impl_->nodes_; + for (const Edge* e : node->out_edges()) { + const int src_slot = e->src_output(); + const int dst_id = e->dst()->id(); + const NodeItem& dst_item = nodes[dst_id]; + if (!e->IsControlEdge()) { + const int dst_slot = e->dst_input(); + input_tensors_[dst_item.input_start + dst_slot] = outputs[src_slot]; + } + if (pending_[dst_id].Dec()) { + ready->push_back(dst_id); + } + } + return Status::OK(); +} + +// NOTE(yuanbyu): Use the executor that supports control flow by default. +const bool use_control_flow_executor = true; +void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { + if (params_.has_control_flow || use_control_flow_executor) { + (new ExecutorState(args, this))->RunAsync(done); + } else { + (new SimpleExecutorState(args, this))->RunAsync(done); + } +} + +} // end namespace + +Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph, + Executor** executor) { + ExecutorImpl* impl = new ExecutorImpl(params, graph); + Status s = impl->Initialize(); + if (s.ok()) { + *executor = impl; + } else { + delete impl; + } + return s; +} + +Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, + const NodeDef& ndef, OpKernel** kernel) { + auto device_type = DeviceType(device->attributes().device_type()); + auto allocator = device->GetAllocator(AllocatorAttributes()); + return CreateOpKernel(device_type, device, allocator, flib, ndef, kernel); +} + +void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; } + +Status CreateCachedKernel(Device* device, const string& session, + FunctionLibraryRuntime* flib, const NodeDef& ndef, + OpKernel** kernel) { + auto op_seg = device->op_segment(); + auto create_fn = [device, flib, &ndef](OpKernel** kernel) { + return CreateNonCachedKernel(device, flib, ndef, kernel); + }; + return op_seg->FindOrCreate(session, ndef.name(), kernel, create_fn); +} + +// Deletes "kernel". +void DeleteCachedKernel(Device* device, const string& session, + OpKernel* kernel) { + // Do nothing. 
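+  // The kernel was handed to the device's OpSegment by CreateCachedKernel()
+  // (via op_seg->FindOrCreate() above), and the segment presumably keeps
+  // ownership for the lifetime of "session", so there is nothing to free
+  // here.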
+}
+
+}  // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
new file mode 100644
index 0000000000..82bcbab836
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor.h
@@ -0,0 +1,209 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class StepStatsCollector;
+
+// Executor runs a graph computation.
+// Example:
+//   Graph* graph = ...;
+//   ... construct graph ...
+//   Executor* executor;
+//   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
+//   Rendezvous* rendezvous = NewNaiveRendezvous();
+//   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
+//   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
+//   TF_CHECK_OK(rendezvous->Recv("input", &output_tensor));
+//   ... ...
+//
+// Multiple threads can call Executor::Run concurrently.
+class Executor {
+ public:
+  virtual ~Executor() {}
+
+  // RunAsync() executes the graph computation. "done" is run when the
+  // graph computation completes. If any error happens during the
+  // computation, "done" is run and the error is passed to "done".
+  //
+  // RunAsync() is given a few arguments in Args. The caller must
+  // ensure objects passed in Args (rendezvous, stats_collector, etc.)
+  // are alive at least until done is invoked. All pointers to the
+  // argument objects can be nullptr.
+  //
+  // RunAsync() uses the given "rendezvous", if not null, as the
+  // mechanism to communicate inputs and outputs of the underlying
+  // graph computation.
+  //
+  // RunAsync() calls "stats_collector", if not null, to keep track of
+  // stats. This allows us to collect statistics and traces on demand.
+  //
+  // RunAsync() is given a "call_frame", if the executor is used for
+  // executing a function, to pass arguments and return values between
+  // the caller and the callee.
+  //
+  // RunAsync() uses "cancellation_manager", if not nullptr, to
+  // register callbacks that should be called if the graph computation
+  // is cancelled. Note that the callbacks merely unblock any
+  // long-running computation, and a cancelled step will terminate by
+  // returning/calling the DoneCallback as usual.
+  //
+  // RunAsync() dispatches closures to "runner". Typically, "runner"
+  // is backed by a bounded threadpool.
+  struct Args {
+    Rendezvous* rendezvous = nullptr;
+    StepStatsCollector* stats_collector = nullptr;
+    FunctionCallFrame* call_frame = nullptr;
+    CancellationManager* cancellation_manager = nullptr;
+
+    typedef std::function<void()> Closure;
+    typedef std::function<void(Closure)> Runner;
+    Runner runner = nullptr;
+  };
+  typedef std::function<void(const Status&)> DoneCallback;
+  virtual void RunAsync(const Args& args, DoneCallback done) = 0;
+
+  // Synchronous wrapper for RunAsync().
+  Status Run(const Args& args) {
+    Status ret;
+    Notification n;
+    RunAsync(args, [&ret, &n](const Status& s) {
+      ret = s;
+      n.Notify();
+    });
+    n.WaitForNotification();
+    return ret;
+  }
+};
+
+// Creates an Executor that computes the given "graph".
+//
+// If successful, returns the constructed executor in "*executor". The
+// caller keeps the ownership of "device". 
The returned executor takes
+// the ownership of "graph". Otherwise, returns an error status.
+//
+// "params" provides a set of context for the executor. We expect that
+// different contexts would provide different implementations.
+struct LocalExecutorParams {
+  Device* device;
+
+  // The library runtime support.
+  FunctionLibraryRuntime* function_library;
+
+  // True iff the computation contains control flow nodes.
+  bool has_control_flow;
+
+  // create_kernel returns an instance of op kernel based on NodeDef.
+  // delete_kernel is called for every kernel used by the executor
+  // when the executor is deleted.
+  std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
+  std::function<void(OpKernel*)> delete_kernel;
+};
+::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
+                                      const Graph* graph, Executor** executor);
+
+// A class to help run multiple executors in parallel and wait until
+// all of them are complete.
+//
+// ExecutorBarrier deletes itself after the function returned by Get()
+// is called.
+class ExecutorBarrier {
+ public:
+  typedef std::function<void(const Status&)> StatusCallback;
+
+  // Create an ExecutorBarrier for 'num' different executors.
+  //
+  // 'r' is the shared Rendezvous object that is used to communicate
+  // state. If any of the executors experiences an error, the
+  // rendezvous object will be aborted exactly once.
+  //
+  // 'done' is called after the last executor completes, and
+  // ExecutorBarrier is deleted.
+  ExecutorBarrier(int num, Rendezvous* r, StatusCallback done)
+      : rendez_(r), done_cb_(done), pending_(num) {}
+
+  ~ExecutorBarrier() {}
+
+  // Returns a closure that Executors must call when they are done
+  // computing, passing the status of their execution as an argument.
+  StatusCallback Get() {
+    return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
+  }
+
+ private:
+  Rendezvous* rendez_ = nullptr;
+  StatusCallback done_cb_ = nullptr;
+
+  mutable mutex mu_;
+  int pending_ GUARDED_BY(mu_) = 0;
+  Status status_ GUARDED_BY(mu_);
+
+  void WhenDone(const Status& s) {
+    bool error = false;
+    StatusCallback done = nullptr;
+    Status status;
+    {
+      mutex_lock l(mu_);
+      // If we are the first error encountered, mark the status
+      // appropriately and later trigger an abort of the Rendezvous
+      // object by this thread only.
+      if (status_.ok() && !s.ok()) {
+        error = true;
+        status_ = s;
+      }
+
+      // If this is the last call to WhenDone, call the final callback
+      // below.
+      if (--pending_ == 0) {
+        CHECK(done_cb_ != nullptr);
+        done = done_cb_;
+        done_cb_ = nullptr;
+      }
+      status = status_;
+    }
+    if (error) {
+      rendez_->StartAbort(status);
+    }
+    if (done != nullptr) {
+      delete this;
+      done(status);
+    }
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
+};
+
+// A few helpers to facilitate create/delete kernels.
+
+// Creates a kernel based on "ndef" on device "device". The kernel can
+// access the functions in the "flib". The caller takes ownership of
+// returned "*kernel".
+Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
+                             const NodeDef& ndef, OpKernel** kernel);
+
+// Deletes "kernel" returned by CreateNonCachedKernel.
+void DeleteNonCachedKernel(OpKernel* kernel);
+
+// Creates a kernel based on "ndef" on device "device". The kernel can
+// access the functions in the "flib". The caller does not take
+// ownership of returned "*kernel". If a kernel has been created for
+// ndef.name(), returns the same kernel instance.
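+//
+// A rough usage sketch (illustrative only; "session_handle", "device",
+// "flib" and "ndef" are placeholders supplied by the caller):
+//
+//   OpKernel* kernel = nullptr;
+//   TF_RETURN_IF_ERROR(
+//       CreateCachedKernel(device, session_handle, flib, ndef, &kernel));
+//   ... run the kernel ...
+//   DeleteCachedKernel(device, session_handle, kernel);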
+Status CreateCachedKernel(Device* device, const string& session,
+                          FunctionLibraryRuntime* flib, const NodeDef& ndef,
+                          OpKernel** kernel);
+
+// Deletes "kernel" returned by CreateCachedKernel.
+void DeleteCachedKernel(Device* device, const string& session,
+                        OpKernel* kernel);
+
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
new file mode 100644
index 0000000000..2b1a041235
--- /dev/null
+++ b/tensorflow/core/common_runtime/function.cc
@@ -0,0 +1,1335 @@
+#include "tensorflow/core/common_runtime/function.h"
+
+#include <deque>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/optimizer_cse.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+// A few string constants used throughout this module.
+static const char* const kArgOp = "_Arg";
+static const char* const kRetOp = "_Retval";
+static const char* const kGradientOp = "SymbolicGradient";
+static const char* const kNodeLabel = "Func";
+
+// Represents the index-th output of a node.
+struct Endpoint {
+  Node* node;
+  int index;
+
+  // Returns the string name that represents this endpoint.
+  string name() const {
+    if (index == 0) {
+      return node->name();
+    } else {
+      return strings::StrCat(node->name(), ":", index);
+    }
+  }
+
+  DataType dtype() const { return node->output_type(index); }
+};
+
+struct EndpointHash {
+  uint64 operator()(const Endpoint& x) const {
+    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
+                  x.index);
+  }
+};
+
+struct EndpointEq {
+  bool operator()(const Endpoint& x, const Endpoint& y) const {
+    return (x.node == y.node) && (x.index == y.index);
+  }
+};
+
+// The following Add* routines are used to add a few graph nodes while
+// functions are transformed.
+static Node* AddNoOp(Graph* g) { + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op("NoOp"); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + return ret; +} + +static Node* AddIdentity(Graph* g, Endpoint input) { + DCHECK_LT(0, input.dtype()); + DCHECK_LT(input.dtype(), DT_FLOAT_REF); + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op("Identity"); + ndef.add_input(input.name()); + AddNodeAttr("T", input.dtype(), &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, ret, 0); + return ret; +} + +static Node* AddArg(Graph* g, DataType dtype, int index) { + DCHECK_LT(0, dtype); + DCHECK_LT(dtype, DT_FLOAT_REF); + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op(kArgOp); + AddNodeAttr("T", dtype, &ndef); + AddNodeAttr("index", index, &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + return ret; +} + +static Node* AddRet(Graph* g, Endpoint input, int index) { + DCHECK_LT(0, input.dtype()); + DCHECK_LT(input.dtype(), DT_FLOAT_REF); + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op(kRetOp); + ndef.add_input(input.name()); + AddNodeAttr("T", input.dtype(), &ndef); + AddNodeAttr("index", index, &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, ret, 0); + return ret; +} + +static Node* AddZerosLike(Graph* g, Endpoint input) { + DCHECK_LT(0, input.dtype()); + DCHECK_LT(input.dtype(), DT_FLOAT_REF); + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op("ZerosLike"); + ndef.add_input(input.name()); + AddNodeAttr("T", input.dtype(), &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, ret, 0); + return ret; +} + +static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice grads) { + const int num_x = n->num_inputs(); + const int num_y = n->num_outputs(); + CHECK_EQ(num_y, grads.size()); + + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op(kGradientOp); + + // The gradient node should have num_x + num_y inputs. + std::vector n_inputs(num_x); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) continue; + n_inputs[e->dst_input()] = {e->src(), e->src_output()}; + } + DataTypeVector in_types; + for (const Endpoint& ep : n_inputs) { + ndef.add_input(ep.name()); + in_types.push_back(ep.dtype()); + } + for (const Endpoint& ep : grads) { + ndef.add_input(ep.name()); + in_types.push_back(ep.dtype()); + } + CHECK_EQ(ndef.input_size(), num_x + num_y); + + AddNodeAttr("Tin", in_types, &ndef); + + // The gradient node's outputs have the same types as the node 'n's + // inputs. 
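+  // For example, if "n" computes z = MatMul(x, w) (num_x = 2, num_y = 1),
+  // the gradient node built here has Tin = {T(x), T(w), T(dz)} and
+  // Tout = {T(x), T(w)}, i.e. inputs (x, w, dz) and outputs (dx, dw).
+  // (Illustrative only; MatMul is not special-cased here.)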
+ AddNodeAttr("Tout", n->input_types(), &ndef); + NameAttrList func; + func.set_name(n->type_string()); + *(func.mutable_attr()) = n->def().attr(); + AddNodeAttr("f", func, &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + return ret; +} + +class ArgOp : public OpKernel { + public: + explicit ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); + } + + void Compute(OpKernelContext* ctx) override { + auto frame = ctx->call_frame(); + OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame")); + Tensor val; + OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); + OP_REQUIRES(ctx, val.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(val.dtype()), + " vs. expect ", DataTypeString(dtype_))); + ctx->set_output(0, val); + } + + private: + int index_; + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); +}; + +REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp); +REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_GPU), ArgOp); + +class RetvalOp : public OpKernel { + public: + explicit RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& val = ctx->input(0); + OP_REQUIRES(ctx, val.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(val.dtype()), + " vs. expect ", DataTypeString(dtype_))); + auto frame = ctx->call_frame(); + OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame")); + OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val)); + } + + private: + int index_; + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); +}; + +REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp); +REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_GPU), RetvalOp); + +static const FunctionLibraryRuntime::Handle kInvalidHandle = -1; + +class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { + public: + FunctionLibraryRuntimeImpl(Device* device, Runner runner, + const FunctionLibraryDefinition* lib_def); + + ~FunctionLibraryRuntimeImpl() override; + + Status Instantiate(const string& function_name, + const InstantiateAttrValueMap& attrs, + Handle* handle) override; + + const FunctionBody* GetFunctionBody(Handle handle) override; + + Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override; + + void Run(const Options& opts, Handle handle, gtl::ArraySlice args, + std::vector* rets, DoneCallback done) override; + + bool IsDefined(const string& function_name) override; + + private: + typedef FunctionLibraryRuntimeImpl ME; + + Device* const device_; + Runner runner_ = nullptr; + const FunctionLibraryDefinition* const lib_def_; + std::function get_func_sig_; + std::function create_kernel_; + + mutable mutex mu_; + + // Maps function instantiation to a handle. The key is a + // canonicalized representation of the function name and + // instantiation attrs. The handle is an index into the items_. + std::unordered_map table_ GUARDED_BY(mu_); + + // func_graphs_ never shrinks or reorders its members. + std::vector func_graphs_ GUARDED_BY(mu_); + + // The instantiated and transformed function is encoded as a Graph + // object, and an executor is created for the graph. 
+  struct Item : public core::RefCounted {
+    Executor* exec = nullptr;
+
+    ~Item() override { delete this->exec; }
+  };
+  std::vector<Item*> items_;
+
+  Status FunctionDefToBody(const FunctionDef& fdef,
+                           const InstantiateAttrValueMap& attrs,
+                           FunctionBody** fbody);
+  Status CreateItem(Handle handle, Item** item);
+  Status GetOrCreateItem(Handle handle, Item** item);
+  Status InstantiateSymbolicGradient(const InstantiateAttrValueMap& attrs,
+                                     FunctionBody** g_body);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
+};
+
+FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
+    Device* device, Runner runner, const FunctionLibraryDefinition* lib_def)
+    : device_(device), runner_(runner), lib_def_(lib_def) {
+  get_func_sig_ = [this](const string& op, const OpDef** sig) {
+    Status s;
+    *sig = lib_def_->LookUp(op, &s);
+    return s;
+  };
+  create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
+    return CreateKernel(ndef, kernel);
+  };
+}
+
+FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
+  for (FunctionBody* p : func_graphs_) delete p;
+  for (Item* item : items_)
+    if (item) item->Unref();
+}
+
+// An asynchronous op kernel which executes an instantiated function
+// defined in a library.
+class CallOp : public AsyncOpKernel {
+ public:
+  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
+      : AsyncOpKernel(ctx), handle_(handle) {}
+
+  ~CallOp() override {}
+
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    FunctionLibraryRuntime* lib = ctx->function_library();
+    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+                      errors::Internal("No function library is provided."),
+                      done);
+    FunctionLibraryRuntime::Options opts;
+    std::vector<Tensor> args;
+    args.reserve(ctx->num_inputs());
+    for (int i = 0; i < ctx->num_inputs(); ++i) {
+      args.push_back(ctx->input(i));
+    }
+    std::vector<Tensor>* rets = new std::vector<Tensor>;
+    lib->Run(opts, handle_, args, rets,
+             [ctx, done, rets](const Status& status) {
+               if (!status.ok()) {
+                 ctx->SetStatus(status);
+               } else {
+                 CHECK_EQ(rets->size(), ctx->num_outputs());
+                 for (size_t i = 0; i < rets->size(); ++i) {
+                   ctx->set_output(i, (*rets)[i]);
+                 }
+               }
+               delete rets;
+               done();
+             });
+  }
+
+ private:
+  FunctionLibraryRuntime::Handle handle_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
+};
+
+const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
+  mutex_lock l(mu_);
+  CHECK_LE(0, h);
+  CHECK_LT(h, func_graphs_.size());
+  return func_graphs_[h];
+}
+
+Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
+                                                OpKernel** kernel) {
+  if (ndef.op() != kGradientOp && (lib_def_->Find(ndef.op()) == nullptr)) {
+    return CreateNonCachedKernel(device_, this, ndef, kernel);
+  }
+
+  // Try to instantiate this function for the func/attr. Maybe it's
+  // cached already.
+  Handle handle;
+  TF_RETURN_IF_ERROR(Instantiate(ndef.op(), ndef.attr(), &handle));
+
+  const FunctionBody* fbody = GetFunctionBody(handle);
+  CHECK_NOTNULL(fbody);
+
+  // Constructs a CallOp kernel for running the instantiated function.
+  Status s;
+  auto device_type = DeviceType(device_->attributes().device_type());
+  OpKernelConstruction construction(
+      device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
+      &fbody->fdef.signature(), this, fbody->arg_types, fbody->ret_types, &s);
+  *kernel = new CallOp(handle, &construction);
+  if (!s.ok()) {
+    delete *kernel;
+  }
+  return s;
+}
+
+Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
+    const FunctionDef& fdef, const InstantiateAttrValueMap& attrs,
+    FunctionBody** fbody) {
+  // Instantiates the function template into a graph def.
+  InstantiationResult result;
+  TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig_, &result));
+
+  Graph* graph = new Graph(lib_def_);
+  GraphConstructorOptions opts;
+  opts.allow_internal_ops = true;
+  opts.expect_device_spec = false;
+  Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
+  if (!s.ok()) {
+    delete graph;
+  } else {
+    *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types, graph);
+  }
+  return s;
+}
+
+Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
+    const InstantiateAttrValueMap& attrs, FunctionBody** g_body) {
+  const AttrValue* f = gtl::FindOrNull(attrs, "f");
+  if (f == nullptr) {
+    return errors::InvalidArgument("SymbolicGradient is missing attr: f");
+  }
+  const auto& func = f->func();
+  const FunctionDef* fdef = lib_def_->Find(func.name());
+  if (fdef == nullptr) {
+    // f is a primitive op.
+    gradient::Creator creator;
+    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
+    if (creator == nullptr) {
+      return errors::InvalidArgument("No gradient is defined for ",
+                                     func.name());
+    }
+    FunctionDef grad_fdef;
+    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
+    TF_RETURN_IF_ERROR(FunctionDefToBody(grad_fdef, func.attr(), g_body));
+  } else {
+    // f is a user-defined function.
+ Handle f_handle; + TF_RETURN_IF_ERROR(Instantiate(func.name(), func.attr(), &f_handle)); + const FunctionBody* f_body = GetFunctionBody(f_handle); + CHECK_NOTNULL(f_body); + *g_body = SymbolicGradient(*f_body); + } + return Status::OK(); +} + +Status FunctionLibraryRuntimeImpl::Instantiate( + const string& function_name, const InstantiateAttrValueMap& attrs, + Handle* handle) { + const string key = Canonicalize(function_name, attrs); + { + mutex_lock l(mu_); + *handle = gtl::FindWithDefault(table_, key, kInvalidHandle); + if (*handle != kInvalidHandle) { + return Status::OK(); + } + } + + Status s; + FunctionBody* fbody = nullptr; + if (function_name == kGradientOp) { + TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(attrs, &fbody)); + } else { + const FunctionDef* fdef = lib_def_->Find(function_name); + if (fdef == nullptr) { + return errors::NotFound("Function ", function_name, " is not defined."); + } + TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody)); + } + + { + mutex_lock l(mu_); + *handle = gtl::FindWithDefault(table_, key, kInvalidHandle); + if (*handle != kInvalidHandle) { + delete fbody; + } else { + *handle = func_graphs_.size(); + table_.insert({key, *handle}); + func_graphs_.push_back(fbody); + items_.resize(func_graphs_.size()); + } + } + return Status::OK(); +} + +static void DumpGraph(const char* label, const Graph* g) { + if (VLOG_IS_ON(1)) { + LOG(INFO) << label << ": " << std::endl << DebugString(g); + } +} + +static void SimplifyGraph(Graph* g) { + if (RemoveListArrayConverter(g)) { + DumpGraph("RemoveListArrayConverter", g); + } + bool changed; + do { + changed = false; + if (RemoveDeadNodes(g)) { + changed = true; + DumpGraph("RemoveDeadNodes", g); + } + if (RemoveIdentityNodes(g)) { + changed = true; + DumpGraph("RemoveIdentityNodes", g); + } + FixupSourceAndSinkEdges(g); + OptimizeCSE(g, nullptr); + DumpGraph("OptimizeCSE", g); + } while (changed); +} + +void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) { + DumpGraph("Initial", *g); + const int kNumInlineRounds = 10; + for (int i = 0; i < kNumInlineRounds; ++i) { + if (!ExpandInlineFunctions(lib, *g)) break; + DumpGraph("ExpandInlineFunctions", *g); + SimplifyGraph(*g); + } + + // Makes a copy so that we densify node ids. + Graph* copy = new Graph((*g)->op_registry()); + CopyGraph(**g, copy); + delete *g; + *g = copy; + DumpGraph("ReCopy", *g); +} + +Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { + const FunctionBody* fbody = GetFunctionBody(handle); + CHECK_NOTNULL(fbody); + Graph* g = new Graph(lib_def_); + CopyGraph(*fbody->graph, g); + OptimizeGraph(this, &g); + + // Creates an executor based on the g. This must be done without + // holding mu_ because create_kernel_ calls back into the library. + LocalExecutorParams params; + params.device = device_; + params.function_library = this; + params.has_control_flow = false; + params.create_kernel = create_kernel_; + params.delete_kernel = [](OpKernel* kernel) { + DeleteNonCachedKernel(kernel); + }; + Executor* exec; + TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &exec)); + + *item = new Item; + (*item)->exec = exec; + return Status::OK(); +} + +Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { + { + mutex_lock l(mu_); + if (handle >= items_.size()) { + return errors::NotFound("Function handle ", handle, + " is not valid. 
Likely an internal error."); + } + *item = items_[handle]; + if (*item != nullptr) { + (*item)->Ref(); + return Status::OK(); + } + } + // NOTE: We need to call CreateItem out of mu_ because creating an + // executor needs to call CreateKernel. + TF_RETURN_IF_ERROR(CreateItem(handle, item)); + + { + mutex_lock l(mu_); + if (items_[handle] == nullptr) { + // Install *item in items_. + items_[handle] = *item; + (*item)->Ref(); + } + } + return Status::OK(); +} + +void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, + gtl::ArraySlice args, + std::vector* rets, + DoneCallback done) { + if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) { + return done(errors::Cancelled("")); + } + const FunctionBody* fbody = GetFunctionBody(handle); + FunctionCallFrame* frame = + new FunctionCallFrame(fbody->arg_types, fbody->ret_types); + Status s = frame->SetArgs(args); + if (!s.ok()) { + delete frame; + return done(s); + } + Item* item = nullptr; + s = GetOrCreateItem(handle, &item); + if (!s.ok()) { + delete frame; + return done(s); + } + Executor::Args exec_args; + exec_args.call_frame = frame; + exec_args.cancellation_manager = opts.cancellation_manager; + exec_args.runner = runner_; + item->exec->RunAsync( + // Executor args + exec_args, + // Done callback. + [item, frame, rets, done](const Status& status) { + item->Unref(); + Status s = status; + if (s.ok()) { + s = frame->GetRetvals(rets); + } + delete frame; + done(s); + }); +} + +bool FunctionLibraryRuntimeImpl::IsDefined(const string& function_name) { + return lib_def_->Find(function_name) != nullptr; +} + +FunctionLibraryRuntime* NewFunctionLibraryRuntime( + Device* device, Runner runner, const FunctionLibraryDefinition* lib_def) { + return new FunctionLibraryRuntimeImpl(device, runner, lib_def); +} + +bool RemoveDeadNodes(Graph* g) { + std::vector visited(g->num_node_ids(), false); + visited[Graph::kSourceId] = true; + visited[Graph::kSinkId] = true; + std::deque q; + for (auto n : g->nodes()) { + if (n->op_def().is_stateful()) { + visited[n->id()] = true; + } else if (n->type_string() == kArgOp) { + visited[n->id()] = true; + } else if (n->type_string() == kRetOp) { + visited[n->id()] = true; + q.push_back(n); + } + } + while (!q.empty()) { + const Node* n = q.front(); + q.pop_front(); + visited[n->id()] = true; + for (auto e : n->in_edges()) { + q.push_back(e->src()); + } + } + bool removed_any = false; + for (Node* n : g->nodes()) { + if (!visited[n->id()]) { + g->RemoveNode(n); + removed_any = true; + } + } + return removed_any; +} + +namespace { +// If 'edges' contains only 1 non-control edge, returns it. Otherwise, +// returns a nullptr. 
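+// Used by RemoveIdentityNodes() below: an Identity node is rewritten away
+// only when it has exactly one incoming data edge.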
+const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) { + const Edge* ret = nullptr; + for (const Edge* e : edges) { + if (e->IsControlEdge() || ret) return nullptr; + ret = e; + } + return ret; +} +} // end namespace + +bool RemoveIdentityNodes(Graph* g) { + bool removed_any = false; + gtl::InlinedVector matches; + for (Node* n : g->nodes()) { + if ((n->type_string() == "Identity") && GetTheOnlyDataEdge(n->in_edges())) { + matches.push_back(n); + } + } + if (!matches.empty()) { + for (Node* n : matches) { + const Edge* in = GetTheOnlyDataEdge(n->in_edges()); + for (const Edge* out : n->out_edges()) { + if (out->IsControlEdge()) { + g->AddControlEdge(in->src(), out->dst()); + } else { + g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input()); + } + } + g->RemoveNode(n); + removed_any = true; + } + } + return removed_any; +} + +bool RemoveListArrayConverter(Graph* g) { + gtl::InlinedVector matches; + for (Node* n : g->nodes()) { + if ((n->type_string() == "_ListToArray") || + (n->type_string() == "_ArrayToList")) { + matches.push_back(n); + } + } + bool removed_any = false; + if (!matches.empty()) { + for (Node* n : matches) { + if (n->num_inputs() != n->num_outputs()) { + continue; // Not expected. Skip. + } + gtl::InlinedVector identity_nodes(n->num_inputs(), nullptr); + + // Process input edges first. + Node* input_control_node = nullptr; + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + if (input_control_node == nullptr) { + // If node "n" has any control dependencies, adds a no-op + // node (input_control_node) which the additional Identity + // nodes depends on and the input_control_node depends on + // the node "n"s control dependencies. + input_control_node = AddNoOp(g); + } + g->AddControlEdge(e->src(), input_control_node); + } else { + const int index = e->dst_input(); + Node** id_node = &identity_nodes[index]; + if (*id_node != nullptr) { + LOG(ERROR) + << "RemoveListArrayConverter unexpected duplicated input: " + << e->dst_input(); + return removed_any; + } + *id_node = AddIdentity(g, {e->src(), e->src_output()}); + } + } + + // If node "n" has any control dependencies, the added identity + // nodes should have control dependencies on input_control_node. + if (input_control_node != nullptr) { + for (Node* id : identity_nodes) { + g->AddControlEdge(input_control_node, id); + } + } + + Node* output_control_node = nullptr; + for (const Edge* e : n->out_edges()) { + if (e->IsControlEdge()) { + if (output_control_node == nullptr) { + // If node "n" is control-depended upon by other nodes, + // adds a no-op node (output_control_node) which those + // nodes will depend on and output_control_node depends on + // all Identity nodes. + output_control_node = AddNoOp(g); + } + g->AddControlEdge(output_control_node, e->dst()); + } else { + Node* id_node = identity_nodes[e->src_output()]; + if (id_node == nullptr) { + LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: " + << e->src_output(); + return removed_any; + } + CHECK(id_node); + g->AddEdge(id_node, 0, e->dst(), e->dst_input()); + } + } + + // If any nodes have control dependencies on node "n", those + // nodes should have control dependencies on + // output_control_node. 
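+      //
+      // Illustrative only: a converter node whose inputs are (a:0, a:1) is
+      // replaced by two Identity nodes,
+      //   a:0 -> Identity0 -> consumers of output 0
+      //   a:1 -> Identity1 -> consumers of output 1
+      // with the optional NoOp nodes carrying any control edges.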
+ if (output_control_node != nullptr) { + for (Node* id : identity_nodes) { + g->AddControlEdge(id, output_control_node); + } + } + + g->RemoveNode(n); + removed_any = true; + } + } + return removed_any; +} + +// Returns true iff the function '*fbody' can be inlined at 'node' +// based on the type signature of 'node' and 'fbody'. +static bool ValidateInlining(const Node* node, const FunctionBody* fbody) { + if (static_cast(node->num_inputs()) != fbody->arg_types.size()) { + return false; + } + if (static_cast(node->num_inputs()) != fbody->arg_nodes.size()) { + return false; + } + if (static_cast(node->num_outputs()) != fbody->ret_types.size()) { + return false; + } + if (static_cast(node->num_outputs()) != fbody->ret_nodes.size()) { + return false; + } + for (int i = 0; i < node->num_inputs(); ++i) { + if (node->input_type(i) != fbody->arg_types[i]) return false; + } + for (int i = 0; i < node->num_outputs(); ++i) { + if (node->output_type(i) != fbody->ret_types[i]) return false; + } + return true; +} + +// Given a "caller" in "graph", which is a function call of a function +// to "fbody". Replaces the "caller" with fbody->graph and connects +// edges properly. +static void InlineFunctionBody(Graph* g, Node* caller, + const FunctionBody* fbody) { + if (!ValidateInlining(caller, fbody)) { + LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. " + << DebugString(fbody->graph); + return; + } + + // Duplicate fbody->graph into 'g'. First, we copy the nodes of + // fbody->graph into 'g' except the source and sink nodes. We copy + // edges among nodes in 'fbody->graph'. + // + // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we + // remember 'y' in node_map[x->id()]. + std::vector node_map(fbody->graph->num_node_ids()); + for (Node* n : fbody->graph->nodes()) { + if (n->IsSource() || n->IsSink()) continue; + CHECK(n->IsOp()); + node_map[n->id()] = g->CopyNode(n); + } + for (const Edge* e : fbody->graph->edges()) { + if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() || + e->dst()->IsSink()) { + continue; + } + Node* src_copy = node_map[e->src()->id()]; + Node* dst_copy = node_map[e->dst()->id()]; + g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); + } + + // Connect input edges. + // + // For data edges coming into "caller", we first compute the + // : for the i-th input in "inputs". We create one + // Identity node for each input. Then, we connect inputs[i] to to + // the i-th identity node added. The nodes that previously connects + // to the j-th output of i-th arg node are reconnected to th i-th + // identity node. + // + // If "caller" has any input control dependencies, we add a NoOp + // node "input_control_node". This "input_control_node" depends on + // what "caller" depends on, and the added identity nodes depend on + // "input_control_node". 
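+  //
+  // inputs[i] below records the (node, output index) endpoint feeding the
+  // caller's i-th data input; control inputs are routed through the NoOp.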
+ std::vector inputs(caller->num_inputs()); + Node* input_control_node = nullptr; + for (const Edge* e : caller->in_edges()) { + if (e->IsControlEdge()) { + if (input_control_node == nullptr) { + input_control_node = AddNoOp(g); + } + g->AddControlEdge(e->src(), input_control_node); + } else { + inputs[e->dst_input()] = {e->src(), e->src_output()}; + } + } + for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) { + Node* arg = node_map[fbody->arg_nodes[i]->id()]; + Node* n = AddIdentity(g, inputs[i]); + if (input_control_node) { + g->AddControlEdge(input_control_node, n); + } + for (const Edge* e : arg->out_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(n, e->dst()); + } else { + g->AddEdge(n, 0, e->dst(), e->dst_input()); + } + } + node_map[fbody->arg_nodes[i]->id()] = n; + g->RemoveNode(arg); // 'arg' is disconnected. + } + + // Connect output edges. + // + // For i-th return node in fbody->graph, we add in "g" an identity + // node (outputs[i-th]). We then reconnect every incoming edge into + // the i-th return node to the added identity node. + // + // For every data edge coming out of "callee"s i-th output, we + // reconnect it to the i-th identity added above. + // + // If "callee" is control-depended upon by any other nodes, we add a + // NoOp node "output_control_node". "output_control_node" depends on + // all identity nodes added above. And nodes previously depend on + // "callee" is changed to depend on "output_control_node". + std::vector outputs(caller->num_inputs()); + for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) { + Node* ret = node_map[fbody->ret_nodes[i]->id()]; + Endpoint data; // Data input for the ret node. + for (const Edge* e : ret->in_edges()) { + if (!e->IsControlEdge()) { + data = {e->src(), e->src_output()}; + break; + } + } + CHECK(data.node != nullptr); + Node* n = AddIdentity(g, data); + outputs[i] = n; + for (const Edge* e : ret->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), n); + } + } + g->RemoveNode(ret); // 'ret' is disconnected. + } + Node* output_control_node = nullptr; + for (const Edge* e : caller->out_edges()) { + if (e->IsControlEdge()) { + if (output_control_node == nullptr) { + output_control_node = AddNoOp(g); + for (Node* n : outputs) { + g->AddControlEdge(n, output_control_node); + } + } + g->AddControlEdge(output_control_node, e->dst()); + } else { + g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input()); + } + } + g->RemoveNode(caller); // 'caller' is replaced with inlined nodes. +} + +bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) { + std::vector> candidates; + for (Node* node : graph->nodes()) { + VLOG(3) << "Expanding " << node->DebugString(); + FunctionLibraryRuntime::Handle handle; + Status s = + lib->Instantiate(node->type_string(), node->def().attr(), &handle); + if (!s.ok()) { + // Either "node" is a primitive op, or the instantiation failed. + if (errors::IsNotFound(s)) { + VLOG(2) << "ExpandInlineFunctions " << s; + } else { + LOG(ERROR) << "ExpandInlineFunctions " << s; + } + continue; + } + const FunctionBody* fbody = lib->GetFunctionBody(handle); + CHECK_NOTNULL(fbody); + candidates.push_back({node, fbody}); + } + for (const auto& p : candidates) { + InlineFunctionBody(graph, p.first, p.second); + } + return !candidates.empty(); +} + +// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef. +// and stash the original NodeDef name as an attr for documentation +// purpose. 
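+//
+// Emits nodes in forward topological order and names them "n<id>", since the
+// original NodeDef names are not stable across the rewrites above. E.g., an
+// edge from node 7's output 1 becomes the input string "n7:1", and a control
+// edge from node 7 becomes "^n7".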
+static void ToGraphDef(const Graph* g, GraphDef* gdef) { + // We visit nodes in forward topological sort order, which is a + // possible execution order of the graph. + std::vector pending(g->num_node_ids()); + std::deque ready; + for (const Node* n : g->nodes()) { + pending[n->id()] = n->in_edges().size(); + if (pending[n->id()] == 0) ready.push_back(n); + } + gtl::InlinedVector inputs; + gdef->Clear(); + while (!ready.empty()) { + const Node* n = ready.front(); + ready.pop_front(); + for (const Edge* e : n->out_edges()) { + const Node* next = e->dst(); + if (--pending[next->id()] == 0) { + ready.push_back(next); + } + } + if (!n->IsOp()) continue; + NodeDef* ndef = gdef->add_node(); + ndef->set_name(strings::StrCat("n", n->id())); + ndef->set_op(n->type_string()); + *(ndef->mutable_attr()) = n->def().attr(); + inputs.clear(); + inputs.resize(n->num_inputs()); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + inputs.push_back(e); + } else { + if (inputs[e->dst_input()] == nullptr) { + inputs[e->dst_input()] = e; + } else { + LOG(WARNING) << "Malformed graph node. multiple input edges: " + << n->DebugString(); + } + } + } + // node->name() is merely NodeDef::name, which are not guaranteed + // to be unique and stable after optimization rewrites. Therefore, + // we use "n" instead. + for (const Edge* e : inputs) { + if (e == nullptr) { + ndef->add_input("unknown"); + } else if (!e->src()->IsOp()) { + } else if (e->IsControlEdge()) { + ndef->add_input(strings::StrCat("^n", e->src()->id())); + } else if (e->src_output() == 0) { + ndef->add_input(strings::StrCat("n", e->src()->id())); + } else { + ndef->add_input( + strings::StrCat("n", e->src()->id(), ":", e->src_output())); + } + } + } +} + +string DebugString(const Graph* g) { + GraphDef gdef; + ToGraphDef(g, &gdef); + return DebugString(gdef); +} + +FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t, + DataTypeSlice ret_t, Graph* g) + : fdef(f), + graph(g), + arg_types(arg_t.begin(), arg_t.end()), + ret_types(ret_t.begin(), ret_t.end()) { + this->arg_nodes.resize(arg_types.size()); + this->ret_nodes.resize(ret_types.size()); + for (Node* n : this->graph->nodes()) { + gtl::InlinedVector* node_vec; + if (n->type_string() == kRetOp) { + node_vec = &this->ret_nodes; + } else if (n->type_string() == kArgOp) { + node_vec = &this->arg_nodes; + } else { + continue; + } + int index; + TF_CHECK_OK(GetNodeAttr(n->def(), "index", &index)); + CHECK_LE(0, index); + CHECK_LT(index, node_vec->size()); + (*node_vec)[index] = n; + } +} + +FunctionBody::~FunctionBody() { delete this->graph; } + +class SymbolicGradientHelper { + public: + explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {} + + ~SymbolicGradientHelper() { delete gbody_; } + + FunctionBody* Compute(); + + private: + const FunctionBody* fbody_; + FunctionBody* gbody_ = nullptr; + + // A vector of output endpoints which represents backpropagated + // gradients + typedef std::vector BackpropedGradients; + + // backprops_ is a map from an output endpoint to its accumulated + // gradients. When an output endpoint has accumulated all its + // gradients, we add a node which sums them up. + std::unordered_map + backprops_; + + // pending[i] is count-down counter for i-th node's expected + // backprops. When pending[i] becomes zero, we collected all + // backprop gradients for all output endpoint of the ith-node. + std::vector pending_; + + // 'ready' keeps track of nodes that have been completely + // backpropped. 
Initially, for every output y of the function f, we + // add dy as an input of the the gradient function. + std::deque ready_; + + // Makes a copy of fbody_ in gbody_. + void Copy(); + + // Initialize pending_ and ready_. + void InitBackprop(); + + // In the original function body, there is a forward edge from 'src' + // to 'dst', when the backprop algorithm constructs the node + // 'dst_grad' which computes the gradient, we need to propagate it + // to 'src'. + void BackpropAlongEdge(const Endpoint& dst_grad, const Endpoint& src); + void BackpropZerosAlongEdge(const Endpoint& src); + + Endpoint SumGradients(const Endpoint& src); + + TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper); +}; + +void SymbolicGradientHelper::Copy() { + const Graph& src = *(fbody_->graph); + gbody_->graph = new Graph(src.op_registry()); + Graph* dst = gbody_->graph; + + std::vector node_map(src.num_node_ids()); + + // Copy the nodes. + node_map[src.source_node()->id()] = dst->source_node(); + node_map[src.sink_node()->id()] = dst->sink_node(); + for (Node* n : src.nodes()) { + if (n->IsSource() || n->IsSink()) continue; + CHECK(n->IsOp()); + node_map[n->id()] = dst->CopyNode(n); + } + + // Copy the edges. + for (const Edge* e : src.edges()) { + Node* src_copy = node_map[e->src()->id()]; + Node* dst_copy = node_map[e->dst()->id()]; + dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); + } + + // Save inputs in copied graph. + CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size()); + gbody_->arg_types = fbody_->arg_types; + for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) { + gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]); + } + + // Save outputs in copied graph. + CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size()); + gbody_->ret_types = fbody_->ret_types; + for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) { + gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]); + } +} + +void SymbolicGradientHelper::BackpropAlongEdge(const Endpoint& dst_grad, + const Endpoint& src) { + CHECK_NOTNULL(src.node); + auto iter = backprops_.find(src); + if (iter != backprops_.end()) { + auto* grads = &iter->second; + grads->push_back(dst_grad); + if (--pending_[src.node->id()] == 0) { + ready_.push_back(src.node); + } + } +} + +void SymbolicGradientHelper::BackpropZerosAlongEdge(const Endpoint& src) { + CHECK_NOTNULL(src.node); + auto iter = backprops_.find(src); + if (iter != backprops_.end()) { + if (--pending_[src.node->id()] == 0) { + ready_.push_back(src.node); + } + } +} + +void SymbolicGradientHelper::InitBackprop() { + Graph* g = gbody_->graph; + pending_.resize(g->num_node_ids(), 0); + { + backprops_.clear(); + std::unordered_set visited; + std::deque queue; + for (Node* n : gbody_->arg_nodes) { + queue.push_back(n); + } + + // Going forward to figure out which endpoints need backprop-ed. + // A node's endpoints need to be backprop-ed only if one of the + // arg node can reach the node via data edges. 
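+    // After this pass, pending_[n] equals the number of outgoing data edges
+    // of n, i.e. how many gradients must arrive before n is ready.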
+ while (!queue.empty()) { + Node* n = queue.front(); + queue.pop_front(); + visited.insert(n); + for (int i = 0; i < n->num_outputs(); ++i) { + backprops_[{n, i}].clear(); + } + int num_expected_backprops = 0; + for (const Edge* e : n->out_edges()) { + if (e->IsControlEdge()) continue; + ++num_expected_backprops; + if (visited.find(e->dst()) == visited.end()) { + queue.push_back(e->dst()); + } + } + pending_[n->id()] = num_expected_backprops; + } + } + + { + const int num_y = gbody_->ret_nodes.size(); + for (int i = 0; i < num_y; ++i) { + Node* y = gbody_->ret_nodes[i]; + DCHECK_EQ(y->type_string(), kRetOp); + const DataType dtype = y->input_type(0); + const int index = gbody_->arg_nodes.size(); + Node* dy = AddArg(g, dtype, index); + gbody_->arg_types.push_back(dtype); + gbody_->arg_nodes.push_back(dy); + + // What's the input to y? + Endpoint y_in{nullptr, 0}; + for (const Edge* e : y->in_edges()) { + if (!e->IsControlEdge()) { + y_in = {e->src(), e->src_output()}; + break; + } + } + CHECK_NOTNULL(y_in.node); + BackpropAlongEdge({dy, 0}, y_in); + } + } +} + +Endpoint SymbolicGradientHelper::SumGradients(const Endpoint& src) { + Graph* g = gbody_->graph; + const DataType dtype = src.dtype(); + auto iter = backprops_.find(src); + CHECK(iter != backprops_.end()); + const auto& grads = iter->second; + if (grads.empty()) { + // Nothing propagated back. The best we can come up is zeros. + Node* zero_like = AddZerosLike(g, src); + return {zero_like, 0}; + } + if (grads.size() == 1) { + // Just one backprop edge. + return grads[0]; + } + // Otherwise, adds backprop-ed gradients. + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op("AddN"); // N-way Add + for (const Endpoint& ep : grads) { + ndef.add_input(ep.name()); + } + AddNodeAttr("N", static_cast(grads.size()), &ndef); + AddNodeAttr("T", dtype, &ndef); + Status s; + Node* add = gbody_->graph->AddNode(ndef, &s); + TF_CHECK_OK(s); + for (size_t i = 0; i < grads.size(); ++i) { + const Endpoint& ep = grads[i]; + g->AddEdge(ep.node, ep.index, add, i); + } + return {add, 0}; +} + +static bool IsPrimitiveOpWithNoGrad(const string& func) { + gradient::Creator creator; + Status s = gradient::GetOpGradientCreator(func, &creator); + return s.ok() && (creator == nullptr); +} + +FunctionBody* SymbolicGradientHelper::Compute() { + CHECK(gbody_ == nullptr); + gbody_ = new FunctionBody; + + // Copy fbody_ into gbody_. + Copy(); + + // Initialize backprops. + InitBackprop(); + + // Backward propagation. + gtl::InlinedVector dy; + Graph* g = gbody_->graph; + while (!ready_.empty()) { + // n has collected all gradients. + Node* n = ready_.front(); + ready_.pop_front(); + + if (n->type_string() == kArgOp) { + // We'll handle the _Arg node after backprop is done. + continue; + } + + // "n" has num_x inputs and num_y outputs. + const int num_x = n->num_inputs(); + const int num_y = n->num_outputs(); + + // dy[i] is the sum of i-th output's backpropped gradients. + dy.clear(); + dy.resize(num_y, {nullptr, 0}); + for (int i = 0; i < num_y; ++i) { + dy[i] = SumGradients({n, i}); + } + + if (IsPrimitiveOpWithNoGrad(n->type_string())) { + // No grad defined for this op. Backprops zeros along the in + // edges. + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) continue; + BackpropZerosAlongEdge({e->src(), e->src_output()}); + } + continue; + } + + // Adds a gradient node with num_x + num_y inputs and num_x + // outputs. 
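+    // I.e., for (y1, ..., y_num_y) = n(x1, ..., x_num_x), the gradient node
+    // computes (dL/dx1, ..., dL/dx_num_x) given
+    // (x1, ..., x_num_x, dL/dy1, ..., dL/dy_num_y).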
+ Node* grad = AddSymGrad(g, n, dy); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) continue; + g->AddEdge(e->src(), e->src_output(), grad, e->dst_input()); + } + for (int i = 0; i < num_y; ++i) { + g->AddEdge(dy[i].node, dy[i].index, grad, num_x + i); + } + + // Backprops along the in edges. + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) continue; + BackpropAlongEdge({grad, e->dst_input()}, {e->src(), e->src_output()}); + } + } + + // The gradient's retval nodes. + for (Node* n : gbody_->ret_nodes) { + g->RemoveNode(n); + } + gbody_->ret_types = fbody_->arg_types; + gbody_->ret_nodes.clear(); + for (size_t i = 0; i < fbody_->arg_types.size(); ++i) { + Endpoint grad = SumGradients({gbody_->arg_nodes[i], 0}); + Node* ret = AddRet(g, grad, i); + gbody_->ret_nodes.push_back(ret); + } + + auto ret = gbody_; + gbody_ = nullptr; + return ret; +} + +FunctionBody* SymbolicGradient(const FunctionBody& f) { + return SymbolicGradientHelper(f).Compute(); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h new file mode 100644 index 0000000000..634b31232a --- /dev/null +++ b/tensorflow/core/common_runtime/function.h @@ -0,0 +1,100 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_ +#define TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_ + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Creates a FunctionLibraryRuntime, which instantiates functions +// defined in "lib_def" and executes functions on the "device". +// +// The returned object does not take ownerships of "device" or +// "lib_def". The caller must ensure "device" and "lib_def" outlives +// the returned object. +typedef std::function Closure; +typedef std::function Runner; +FunctionLibraryRuntime* NewFunctionLibraryRuntime( + Device* device, Runner runner, const FunctionLibraryDefinition* lib_def); + +// FunctionLibraryRuntime::GetFunctionBody returns a description of an +// instantiated function that is represented as a Graph with arg/ret +// nodes annotated. +struct FunctionBody { + FunctionDef fdef; + Graph* graph = nullptr; // owned. + DataTypeVector arg_types; + DataTypeVector ret_types; + gtl::InlinedVector arg_nodes; + gtl::InlinedVector ret_nodes; + + FunctionBody() {} + FunctionBody(const FunctionDef& f, DataTypeSlice arg_types, + DataTypeSlice ret_types, Graph* g); + ~FunctionBody(); +}; + +// Debugging facility. Returns a debug string for a graph +// representing an instantiated function. +string DebugString(const Graph* instantiated_func_graph); + +// A few hand-crafted optimization on the instantiated function body +// (a Graph*). + +// Removes nodes that are +// 1. not stateful; and +// 2. not _Arg; and +// 3. not reachable from _Retval. +// Returns true iff any node is removed from "g". +bool RemoveDeadNodes(Graph* g); + +// Find a pattern: +// src -(in)-> node -(out)-> dst, where +// 1) node is an identity node; +// 2) in is the only incoming data edge; +// 3) out is the only outgoing data edge; +// +// Rewrites the above pattern with src->dst and relevant data +// dependencies updated. Repeat the process until no such pattern +// left. +bool RemoveIdentityNodes(Graph* g); + +// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes. 
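+// Returns true iff any such node was rewritten.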
+bool RemoveListArrayConverter(Graph* g); + +// For each node in "graph", if "lib" indicates that the node is a +// function call, inline the function body. Returns true if at least +// one node is inlined. +// +// This routine goes through "graph" nodes once and applies the +// inlining. The caller may decide to apply the inlining on "graph" +// multiple times by calling ExpandInlineFunctions a few times. +bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph); + +// Applies graph rewrite optimzation such as inlining, dead code +// removal, etc. +// +// **g is a graph constructed based on the runtime library 'lib'. +// OptimizeGraph mutates **g extensively and replaces '*g' with a +// complete copy. Therefore, the caller should not keep any references +// to nodes *g. +void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g); + +// Given a numerical function "f", returns another numerical function +// "g", such that if "f" takes N inputs and produces M outputs, "g" +// takes N + M inputs and produces N outputs. I.e., if +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// g is a function which is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (...x_i...). +// +// TODO(zhifengc): Asks math expert to say the comment again. +FunctionBody* SymbolicGradient(const FunctionBody& f); + +} // end namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_ diff --git a/tensorflow/core/common_runtime/gpu/dma_helper.h b/tensorflow/core/common_runtime/gpu/dma_helper.h new file mode 100644 index 0000000000..7b0750f405 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/dma_helper.h @@ -0,0 +1,18 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_ + +#include "tensorflow/core/public/tensor.h" + +// For internal use only. Visibility should be limited to brain/framework. 
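+//
+// DMAHelper exposes a Tensor's underlying buffer so the GPU runtime can DMA
+// tensor contents directly instead of going through the public Tensor API.
+// Illustrative use (not part of this change):
+//
+//   if (DMAHelper::CanUseDMA(&t)) {
+//     const void* src = DMAHelper::base(&t);
+//     // ... enqueue a device memcpy of t.TotalBytes() starting at 'src' ...
+//   }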
+ +namespace tensorflow { +class DMAHelper { + public: + static bool CanUseDMA(const Tensor* t) { return t->CanUseDMA(); } + static const void* base(const Tensor* t) { return t->base(); } + static void* base(Tensor* t) { return t->base(); } + static TensorBuffer* buffer(Tensor* t) { return t->buf_; } + static const TensorBuffer* buffer(const Tensor* t) { return t->buf_; } +}; +} // namespace tensorflow +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc new file mode 100644 index 0000000000..742459c63b --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc @@ -0,0 +1,49 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +GPUAllocatorRetry::GPUAllocatorRetry() : env_(Env::Default()) {} + +void* GPUAllocatorRetry::AllocateRaw( + std::function alloc_func, + int max_millis_to_wait, size_t alignment, size_t num_bytes) { + if (num_bytes == 0) { + LOG(WARNING) << "Request to allocate 0 bytes"; + return nullptr; + } + uint64 deadline_micros = env_->NowMicros() + max_millis_to_wait * 1000; + void* ptr = nullptr; + while (ptr == nullptr) { + ptr = alloc_func(alignment, num_bytes, false); + if (ptr == nullptr) { + uint64 now = env_->NowMicros(); + if (now < deadline_micros) { + mutex_lock l(mu_); + WaitForMilliseconds(&l, &memory_returned_, + (deadline_micros - now) / 1000); + } else { + return alloc_func(alignment, num_bytes, true); + } + } + } + return ptr; +} + +void GPUAllocatorRetry::DeallocateRaw(std::function dealloc_func, + void* ptr) { + if (ptr == nullptr) { + LOG(ERROR) << "Request to free nullptr"; + return; + } + dealloc_func(ptr); + { + mutex_lock l(mu_); + memory_returned_.notify_all(); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h new file mode 100644 index 0000000000..a3298ab222 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h @@ -0,0 +1,36 @@ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { + +// A retrying wrapper for a memory allocator. +class GPUAllocatorRetry { + public: + GPUAllocatorRetry(); + + // Call 'alloc_func' to obtain memory. On first call, + // 'verbose_failure' will be false. If return value is nullptr, + // then wait up to 'max_millis_to_wait' milliseconds, retrying each + // time a call to DeallocateRaw() is detected, until either a good + // pointer is returned or the deadline is exhausted. If the + // deadline is exahusted, try one more time with 'verbose_failure' + // set to true. The value returned is either the first good pointer + // obtained from 'alloc_func' or nullptr. + void* AllocateRaw(std::function alloc_func, + int max_millis_to_wait, size_t alignment, size_t bytes); + + // Calls dealloc_func(ptr) and then notifies any threads blocked in + // AllocateRaw() that would like to retry. 
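+  // Only deallocations routed through this method wake those waiters; memory
+  // freed behind the allocator's back does not unblock a pending retry.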
+ void DeallocateRaw(std::function dealloc_func, void* ptr); + + private: + Env* env_; + mutex mu_; + condition_variable memory_returned_; +}; +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc new file mode 100644 index 0000000000..db1c58cc65 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc @@ -0,0 +1,175 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h" + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/env.h" +#include + +namespace tensorflow { +namespace { + +class FakeAllocator { + public: + FakeAllocator(size_t cap, int millis_to_wait) + : memory_capacity_(cap), millis_to_wait_(millis_to_wait) {} + + // Allocate just keeps track of the number of outstanding allocations, + // not their sizes. Assume a constant size for each. + void* AllocateRaw(size_t alignment, size_t num_bytes) { + return retry_.AllocateRaw( + [this](size_t a, size_t nb, bool v) { + mutex_lock l(mu_); + if (memory_capacity_ > 0) { + --memory_capacity_; + return good_ptr_; + } else { + return static_cast(nullptr); + } + }, + millis_to_wait_, alignment, num_bytes); + } + + void DeallocateRaw(void* ptr) { + retry_.DeallocateRaw( + [this](void* p) { + mutex_lock l(mu_); + ++memory_capacity_; + }, + ptr); + } + + private: + GPUAllocatorRetry retry_; + void* good_ptr_ = reinterpret_cast(0xdeadbeef); + mutex mu_; + size_t memory_capacity_ GUARDED_BY(mu_); + int millis_to_wait_; +}; + +class GPUAllocatorRetryTest : public ::testing::Test { + protected: + GPUAllocatorRetryTest() {} + + void LaunchConsumerThreads(int num_consumers, int cap_needed) { + consumer_count_.resize(num_consumers, 0); + for (int i = 0; i < num_consumers; ++i) { + consumers_.push_back(Env::Default()->StartThread( + ThreadOptions(), "anon_thread", [this, i, cap_needed]() { + do { + void* ptr = nullptr; + for (int j = 0; j < cap_needed; ++j) { + ptr = alloc_->AllocateRaw(16, 1); + if (ptr == nullptr) { + mutex_lock l(mu_); + has_failed_ = true; + return; + } + } + ++consumer_count_[i]; + for (int j = 0; j < cap_needed; ++j) { + alloc_->DeallocateRaw(ptr); + } + } while (!notifier_.HasBeenNotified()); + })); + } + } + + // Wait up to wait_micros microseconds for has_failed_ to equal expected, + // then terminate all threads. + void JoinConsumerThreads(bool expected, int wait_micros) { + while (wait_micros > 0) { + { + mutex_lock l(mu_); + if (has_failed_ == expected) break; + } + int interval_micros = std::min(1000, wait_micros); + Env::Default()->SleepForMicroseconds(interval_micros); + wait_micros -= interval_micros; + } + notifier_.Notify(); + for (auto c : consumers_) { + // Blocks until thread terminates. + delete c; + } + } + + std::unique_ptr alloc_; + std::vector consumers_; + std::vector consumer_count_; + Notification notifier_; + mutex mu_; + bool has_failed_ GUARDED_BY(mu_) = false; + int count_ GUARDED_BY(mu_) = 0; +}; + +// Verifies correct retrying when memory is slightly overcommitted but +// we allow retry. +TEST_F(GPUAllocatorRetryTest, RetrySuccess) { + // Support up to 2 allocations simultaneously, waits up to 10 msec for + // a chance to alloc. 
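+  // (The wait bound is given in milliseconds, so 10000 here allows up to 10
+  // seconds per retry.)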
+ alloc_.reset(new FakeAllocator(2, 10000)); + // Launch 3 consumers, each of whom needs 1 unit at a time. + LaunchConsumerThreads(3, 1); + // This should be enough time for each consumer to be satisfied many times. + Env::Default()->SleepForMicroseconds(50000); + JoinConsumerThreads(false, 0); + for (int i = 0; i < 3; ++i) { + LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i]; + } + { + mutex_lock l(mu_); + EXPECT_FALSE(has_failed_); + } + EXPECT_GT(consumer_count_[0], 0); + EXPECT_GT(consumer_count_[1], 0); + EXPECT_GT(consumer_count_[2], 0); +} + +// Verifies OutOfMemory failure when memory is slightly overcommitted +// and retry is not allowed. +TEST_F(GPUAllocatorRetryTest, NoRetryFail) { + // Support up to 2 allocations simultaneously, waits up to 0 msec for + // a chance to alloc. + alloc_.reset(new FakeAllocator(2, 0)); + // Launch 3 consumers, each of whom needs 1 unit at a time. + LaunchConsumerThreads(3, 1); + Env::Default()->SleepForMicroseconds(50000); + // Will wait up to 10 seconds for proper race condition to occur, resulting + // in failure. + JoinConsumerThreads(true, 10000000); + for (int i = 0; i < 3; ++i) { + LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i]; + } + { + mutex_lock l(mu_); + EXPECT_TRUE(has_failed_); + } +} + +// Verifies OutOfMemory failure when retry is allowed but memory capacity +// is too low even for retry. +TEST_F(GPUAllocatorRetryTest, RetryInsufficientFail) { + // Support up to 2 allocations simultaneously, waits up to 10 msec for + // a chance to alloc. + alloc_.reset(new FakeAllocator(2, 10000)); + // Launch 3 consumers, each of whom needs 2 units at a time. We expect + // deadlock where 2 consumers each hold 1 unit, and timeout trying to + // get the second. + LaunchConsumerThreads(3, 2); + Env::Default()->SleepForMicroseconds(50000); + // Will wait up to 10 seconds for proper race condition to occur, resulting + // in failure. + JoinConsumerThreads(true, 10000000); + for (int i = 0; i < 3; ++i) { + LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i]; + } + { + mutex_lock l(mu_); + EXPECT_TRUE(has_failed_); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc new file mode 100644 index 0000000000..3df833594f --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc @@ -0,0 +1,397 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" + +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/stream_executor.h" +#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h" +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory) + : device_id_(device_id) { + // Get a pointer to the stream_executor for this device + stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); + + // Allocate the requested amount of memory. 
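+  // The entire region is requested from the StreamExecutor up front and is
+  // only returned to the device when this allocator is destroyed.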
+ gpu_memory_size_ = total_memory; + + LOG(INFO) << "Allocating " << strings::HumanReadableNumBytes(gpu_memory_size_) + << " bytes."; + gpu::DeviceMemory gpu_mem = + stream_exec_->AllocateArray(gpu_memory_size_); + + QCHECK(gpu_mem != nullptr) + << " Could not allocate GPU device memory for device " << device_id + << ". Tried to allocate " + << strings::HumanReadableNumBytes(gpu_memory_size_); + base_ptr_ = gpu_mem.opaque(); + LOG(INFO) << "GPU " << device_id << " memory begins at " << base_ptr_ + << " extends to " + << static_cast( + (static_cast(base_ptr_) + gpu_memory_size_)); + + // Create a bunch of bins of various good sizes. + + // Covers allocations of exactly 256 bytes (the minimum size). + bins_.insert(std::make_pair(256, new Bin(256))); + + // We create bins to fit all possible ranges that cover the + // gpu_memory_size_ starting from allocations up to 1024 bytes to + // allocations up to (and including) the memory limit. + for (size_t bin_size = 1024; bin_size < gpu_memory_size_ * 2; bin_size *= 2) { + LOG(INFO) << "Creating bin of max chunk size " + << strings::HumanReadableNumBytes(bin_size); + bins_.insert(std::make_pair(bin_size, new Bin(bin_size))); + } + + // Create one large chunk for the whole memory space that will + // be chunked later. + GPUBFCAllocator::Chunk* c = new GPUBFCAllocator::Chunk(); + c->ptr = gpu_mem.opaque(); + c->size = gpu_memory_size_; + c->in_use = false; + c->prev = nullptr; + c->next = nullptr; + + ptr_to_chunk_map_.insert(std::make_pair(c->ptr, c)); + + // Insert the chunk into the right bin. + ReassignChunkToBin(c); +} + +GPUBFCAllocator::~GPUBFCAllocator() { + // Return memory back. + if (base_ptr_) { + gpu::DeviceMemoryBase gpu_ptr{base_ptr_}; + stream_exec_->Deallocate(&gpu_ptr); + } + + gtl::STLDeleteValues(&bins_); +} + +void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) { + static const int64 kMaxMillisToWait = 10000; // 10 seconds + return retry_helper_.AllocateRaw( + [this](size_t a, size_t nb, bool v) { + return AllocateRawInternal(a, nb, v); + }, + kMaxMillisToWait, unused_alignment, num_bytes); +} + +void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment, + size_t num_bytes, + bool dump_log_on_failure) { + if (num_bytes == 0) { + LOG(ERROR) << "tried to allocate 0 bytes"; + return nullptr; + } + // First, always allocate memory of at least 256 bytes, and always + // allocate multiples of 256 bytes so all memory addresses are + // nicely byte aligned. + size_t rounded_bytes = (256 * ((num_bytes + 255) / 256)); + DCHECK_EQ(0, rounded_bytes % 256); + + // The BFC allocator tries to find the best fit first. + // + // First identify the first bin that could satisfy rounded_bytes. + auto it = bins_.lower_bound(rounded_bytes); + if (it == bins_.end()) { + LOG(ERROR) << " Asked for " << rounded_bytes << " but largest bin was " + << bins_.rbegin()->first; + return nullptr; + } + + mutex_lock l(lock_); + for (; it != bins_.end(); ++it) { + // Start searching from the first bin for the smallest chunk that fits + // rounded_bytes. + Bin* b = it->second; + for (GPUBFCAllocator::Chunk* chunk : b->chunks) { + if (!chunk->in_use && chunk->size > rounded_bytes) { + // We found an existing chunk that fits us that wasn't in use. + chunk->in_use = true; + + // If we can break the size of the chunk into two reasonably + // large pieces, do so. + // + // TODO(vrv): What should be the criteria when deciding when + // to split? 
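+          // Currently we split whenever the chunk is at least twice the
+          // rounded request, e.g. a 1MiB free chunk serving a 256KiB request
+          // becomes a 256KiB in-use chunk plus a 768KiB free remainder.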
+ if (chunk->size >= rounded_bytes * 2) { + SplitChunk(chunk, rounded_bytes); + } + + // The requested size of the returned chunk is what the user + // has allocated. + chunk->requested_size = num_bytes; + + VLOG(4) << "Returning: " << chunk->ptr; + return chunk->ptr; + } + } + } + + // We searched all bins for an existing free chunk to use and + // couldn't find one. This means we must have run out of memory, + // Dump the memory log for analysis. + if (dump_log_on_failure) { + DumpMemoryLog(rounded_bytes); + LOG(WARNING) << "Ran out of memory trying to allocate " + << strings::HumanReadableNumBytes(num_bytes) + << ". See logs for memory state"; + } + return nullptr; +} + +void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) { + // Create a new chunk starting num_bytes after c + GPUBFCAllocator::Chunk* new_chunk = new GPUBFCAllocator::Chunk(); + new_chunk->ptr = static_cast(static_cast(c->ptr) + num_bytes); + VLOG(6) << "Adding to chunk map: " << new_chunk->ptr; + ptr_to_chunk_map_.insert(std::make_pair(new_chunk->ptr, new_chunk)); + + // Set the new sizes of the chunks. + new_chunk->size = c->size - num_bytes; + c->size = num_bytes; + + // The new chunk is not in use. + new_chunk->in_use = false; + + // Maintain the pointers. + // c <-> c_neighbor becomes + // c <-> new_chunk <-> c_neighbor + GPUBFCAllocator::Chunk* c_neighbor = c->next; + new_chunk->prev = c; + new_chunk->next = c_neighbor; + c->next = new_chunk; + if (c_neighbor) { + c_neighbor->prev = new_chunk; + } + + // Maintain the bins + ReassignChunkToBin(new_chunk); + ReassignChunkToBin(c); +} + +void GPUBFCAllocator::DeallocateRaw(void* ptr) { + retry_helper_.DeallocateRaw([this](void* p) { DeallocateRawInternal(p); }, + ptr); +} + +void GPUBFCAllocator::DeallocateRawInternal(void* ptr) { + if (ptr == nullptr) { + LOG(ERROR) << "tried to deallocate nullptr"; + return; + } + mutex_lock l(lock_); + + // Find the chunk from the ptr. + auto it = ptr_to_chunk_map_.find(ptr); + CHECK(it != ptr_to_chunk_map_.end()) + << "Asked to deallocate a pointer we never allocated: " << ptr; + + GPUBFCAllocator::Chunk* c = it->second; + VLOG(6) << "Chunk at " << c->ptr << " no longer in use"; + // Mark the chunk as no longer in use + c->in_use = false; + + // Consider coalescing it. + MaybeCoalesce(c); +} + +// Merges c1 and c2 when c1->next is c2 and c2->prev is c1. +// We merge c2 into c1. +void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1, + GPUBFCAllocator::Chunk* c2) { + // We can only merge chunks that are not in use. + DCHECK(!c1->in_use && !c2->in_use); + + // c1's prev doesn't change, still points to the same ptr, and is + // still not in use. + + // Fix up neighbor pointers + // + // c1 <-> c2 <-> c3 should become + // c1 <-> c3 + GPUBFCAllocator::Chunk* c3 = c2->next; + c1->next = c3; + CHECK(c2->prev == c1); + if (c3 != nullptr) { + c3->prev = c1; + } + + // Set the new size + c1->size += c2->size; + + // Delete c2 and cleanup all state + RemoveChunkFromBin(c2); +} + +void GPUBFCAllocator::ReassignChunkToBin(GPUBFCAllocator::Chunk* c) { + auto it = bins_.lower_bound(c->size); + CHECK(it != bins_.end()) << " Tried to reassign to non-existent bin for size " + << c->size; + + Bin* new_bin = it->second; + + // If the bin has not changed, do nothing. + Bin* old_bin = c->bin; + if (old_bin != nullptr && new_bin == old_bin) { + return; + } + + // The bin has changed. Add the chunk to the new bin and remove + // the chunk from the old bin. 
+ new_bin->chunks.insert(c); + c->bin = new_bin; + + if (old_bin == nullptr) { + return; + } + + // Remove chunk from old bin + for (auto it = old_bin->chunks.begin(); it != old_bin->chunks.end(); ++it) { + if (*it == c) { + old_bin->chunks.erase(it); + return; + } + } + CHECK(false) << "Could not find chunk in old bin"; +} + +void GPUBFCAllocator::RemoveChunkFromBin(GPUBFCAllocator::Chunk* c) { + Bin* b = c->bin; + for (auto it = b->chunks.begin(); it != b->chunks.end(); ++it) { + Chunk* other_c = *it; + if (other_c->ptr == c->ptr) { + b->chunks.erase(it); + VLOG(4) << "Removing: " << c->ptr; + ptr_to_chunk_map_.erase(c->ptr); + delete c; + return; + } + } + + CHECK(false) << "Could not find chunk in bin"; +} + +void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) { + // This chunk is no longer in-use, consider coalescing the chunk + // with adjacent chunks. + Chunk* chunk_to_reassign = nullptr; + + // If the next chunk is free, coalesce the two, if the result would + // fit in an existing bin. + if (c->next && !c->next->in_use) { + VLOG(8) << "Chunk at " << c->next->ptr << " merging with c " << c->ptr; + + chunk_to_reassign = c; + + // Deletes c->next + Merge(c, c->next); + } + + // If the previous chunk is free, coalesce the two + if (c->prev && !c->prev->in_use) { + VLOG(8) << "Chunk at " << c->ptr << " merging into c->prev " + << c->prev->ptr; + + chunk_to_reassign = c->prev; + + // Deletes c + Merge(c->prev, c); + } + + // Reassign the final merged chunk into the right bin. + if (chunk_to_reassign) { + ReassignChunkToBin(chunk_to_reassign); + } +} + +void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) { + VLOG(1) << "AddVisitor"; + mutex_lock l(lock_); + region_visitors_.push_back(visitor); + visitor(base_ptr_, gpu_memory_size_); +} + +bool GPUBFCAllocator::TracksAllocationSizes() { return true; } + +size_t GPUBFCAllocator::RequestedSize(void* ptr) { + mutex_lock l(lock_); + auto it = ptr_to_chunk_map_.find(ptr); + CHECK(it != ptr_to_chunk_map_.end()) + << "Asked for requested size of pointer we never allocated: " << ptr; + GPUBFCAllocator::Chunk* c = it->second; + return c->requested_size; +} + +size_t GPUBFCAllocator::AllocatedSize(void* ptr) { + mutex_lock l(lock_); + auto it = ptr_to_chunk_map_.find(ptr); + CHECK(it != ptr_to_chunk_map_.end()) + << "Asked for allocated size of pointer we never allocated: " << ptr; + GPUBFCAllocator::Chunk* c = it->second; + return c->size; +} + +void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) { + // For each bin: tally up the total number of chunks and bytes. + for (auto bit : bins_) { + Bin* b = bit.second; + + size_t total_bytes_in_use = 0; + size_t total_bytes_in_bin = 0; + size_t total_requested_bytes_in_use = 0; + size_t total_requested_bytes_in_bin = 0; + size_t total_chunks_in_use = 0; + size_t total_chunks_in_bin = 0; + for (Chunk* c : b->chunks) { + total_bytes_in_bin += c->size; + total_requested_bytes_in_bin += c->requested_size; + ++total_chunks_in_bin; + if (c->in_use) { + total_bytes_in_use += c->size; + total_requested_bytes_in_use += c->requested_size; + ++total_chunks_in_use; + } + } + + LOG(INFO) << "Bin (" << b->bin_size + << "): \tTotal Chunks: " << total_chunks_in_bin + << ", Chunks in use: " << total_chunks_in_use << " " + << strings::HumanReadableNumBytes(total_bytes_in_bin) + << " allocated for chunks. " + << strings::HumanReadableNumBytes(total_requested_bytes_in_bin) + << " client-requested for chunks. " + << strings::HumanReadableNumBytes(total_bytes_in_use) + << " in use in bin. 
" + << strings::HumanReadableNumBytes(total_requested_bytes_in_use) + << " client-requested in use in bin."; + } + + // Find the bin that we would have liked to allocate in, so we + // can get some further analysis about fragmentation. + auto it = bins_.lower_bound(num_bytes); + if (it != bins_.end()) { + Bin* b = it->second; + + LOG(INFO) << "Bin for " << strings::HumanReadableNumBytes(num_bytes) + << " was " << strings::HumanReadableNumBytes(b->bin_size) + << ", Chunk State: "; + + for (Chunk* c : b->chunks) { + LOG(INFO) << c->DebugString(true); + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h new file mode 100644 index 0000000000..3d1601e132 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h @@ -0,0 +1,156 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_ + +#include +#include +#include +#include + +#include "tensorflow/stream_executor/stream_executor.h" +#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h" +#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// A GPU memory allocator that implements a 'best-fit with coalescing' +// algorithm. This is essentially a very simple version of Doug Lea's +// malloc (dlmalloc). +// +// The goal of this allocator is to support defragmentation via +// coalescing. One assumption we make is that the process using this +// allocator owns pretty much all of the GPU memory, and that nearly +// all requests to allocate GPU memory go through this interface. +class GPUBFCAllocator : public VisitableAllocator { + public: + // 'device_id' refers to the StreamExecutor ID of the device within + // the process and must reference a valid ID in the process. + explicit GPUBFCAllocator(int device_id, size_t total_memory); + ~GPUBFCAllocator() override; + + string Name() override { return "gpu_bfc"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + + void AddAllocVisitor(Visitor visitor) override; + + // Does nothing, because gpu memory is never freed. + void AddFreeVisitor(Visitor visitor) override {} + + bool TracksAllocationSizes() override; + + size_t RequestedSize(void* ptr) override; + + size_t AllocatedSize(void* ptr) override; + + private: + struct Bin; + + void* AllocateRawInternal(size_t alignment, size_t num_bytes, + bool dump_log_on_failure); + void DeallocateRawInternal(void* ptr); + + // Chunks point to GPU memory. Their prev/next pointers form a + // doubly-linked list of addresses sorted by GPU base address that + // must be contiguous. Chunks contain information about whether + // they are in use or whether they are free, and contain a pointer + // to the bin they are in. + struct Chunk { + size_t size = 0; // Full size of GPU buffer. + + // We sometimes give chunks that are larger than needed to reduce + // fragmentation. requested_size keeps track of what the client + // actually wanted so we can understand whether our splitting + // strategy is efficient. + size_t requested_size = 0; + + bool in_use = false; + void* ptr = nullptr; // pointer to granted GPU subbuffer. 
+ + // If not null, the memory referred to by 'prev' is directly + // preceding the memory used by this chunk. E.g., It should start + // at 'ptr - prev->size' + Chunk* prev = nullptr; + + // If not null, the memory referred to by 'next' is directly + // following the memory used by this chunk. E.g., It should be at + // 'ptr + size' + Chunk* next = nullptr; + + // What bin are we in? + Bin* bin = nullptr; + + string DebugString(bool recurse) { + string dbg; + strings::StrAppend(&dbg, " Size: ", strings::HumanReadableNumBytes(size), + " | Requested Size: ", + strings::HumanReadableNumBytes(requested_size), + " | in_use: ", in_use); + if (recurse && prev) { + strings::StrAppend(&dbg, ", prev: ", prev->DebugString(false)); + } + if (recurse && next) { + strings::StrAppend(&dbg, ", next: ", next->DebugString(false)); + } + return dbg; + } + }; + + Chunk* AllocateNewChunk(size_t num_bytes); + void SplitChunk(Chunk* c, size_t num_bytes); + void Merge(Chunk* c1, Chunk* c2); + void MaybeCoalesce(Chunk* c); + + void ReassignChunkToBin(Chunk* c); + void RemoveChunkFromBin(Chunk* c); + + void DumpMemoryLog(size_t num_bytes); + + // A Bin is a collection of similar-sized Chunks. + struct Bin { + // All chunks in this bin have >= bin_size memory. + size_t bin_size = 0; + + struct ChunkComparator { + bool operator()(Chunk* a, Chunk* b) { return a->size < b->size; } + }; + + // List of chunks within the bin, sorted by chunk size. + std::multiset chunks; + + explicit Bin(size_t bs) : bin_size(bs) {} + + ~Bin() { gtl::STLDeleteElements(&chunks); } + }; + + GPUAllocatorRetry retry_helper_; + + // Structures immutable after construction + const int device_id_; + // The base pointer where all the GPU memory begins. + void* base_ptr_ = nullptr; + size_t gpu_memory_size_ = 0; + + // Map from bin size to Bin + // After construction, the bin map is never resized. + std::map bins_; + + perftools::gputools::StreamExecutor* stream_exec_; // Not owned. + + // Structures mutable after construction + mutable mutex lock_; + // Not owned. + std::unordered_map ptr_to_chunk_map_; + + // Called once on each region, ASAP. + std::vector region_visitors_; + + TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc new file mode 100644 index 0000000000..7b5e8aec1d --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -0,0 +1,166 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" + +#include +#include + +#include "tensorflow/stream_executor/stream_executor.h" +#include +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { +namespace { + +TEST(GPUBFCAllocatorTest, NoDups) { + GPUBFCAllocator a(0, 1 << 30); + // Allocate a lot of raw pointers + std::vector ptrs; + for (int s = 1; s < 1024; s++) { + void* raw = a.AllocateRaw(1, s); + ptrs.push_back(raw); + } + + std::sort(ptrs.begin(), ptrs.end()); + + // Make sure none of them are equal, and that none of them overlap. 
+ for (int i = 0; i < ptrs.size(); i++) { + if (i > 0) { + ASSERT_NE(ptrs[i], ptrs[i - 1]); // No dups + size_t req_size = a.RequestedSize(ptrs[i - 1]); + ASSERT_GT(req_size, 0); + ASSERT_GE(static_cast(ptrs[i]) - static_cast(ptrs[i - 1]), + req_size); + } + } + + for (int i = 0; i < ptrs.size(); i++) { + a.DeallocateRaw(ptrs[i]); + } +} + +TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) { + GPUBFCAllocator a(0, 1 << 30); + // Allocate 256 raw pointers of sizes between 100 bytes and about + // a meg + random::PhiloxRandom philox(123, 17); + random::SimplePhilox rand(&philox); + + std::vector initial_ptrs; + for (int s = 1; s < 256; s++) { + size_t size = std::min( + std::max(rand.Rand32() % 1048576, 100), 1048576); + void* raw = a.AllocateRaw(1, size); + + initial_ptrs.push_back(raw); + } + + // Deallocate half of the memory, and keep track of the others. + std::vector existing_ptrs; + for (int i = 0; i < initial_ptrs.size(); i++) { + if (i % 2 == 1) { + a.DeallocateRaw(initial_ptrs[i]); + } else { + existing_ptrs.push_back(initial_ptrs[i]); + } + } + + // Allocate a lot of raw pointers + for (int s = 1; s < 256; s++) { + size_t size = std::min( + std::max(rand.Rand32() % 1048576, 100), 1048576); + void* raw = a.AllocateRaw(1, size); + existing_ptrs.push_back(raw); + } + + std::sort(existing_ptrs.begin(), existing_ptrs.end()); + // Make sure none of them are equal + for (int i = 0; i < existing_ptrs.size(); i++) { + if (i > 0) { + CHECK_NE(existing_ptrs[i], existing_ptrs[i - 1]); // No dups + + size_t req_size = a.RequestedSize(existing_ptrs[i - 1]); + ASSERT_GT(req_size, 0); + + // Check that they don't overlap. + ASSERT_GE(static_cast(existing_ptrs[i]) - + static_cast(existing_ptrs[i - 1]), + req_size); + } + } + + for (int i = 0; i < existing_ptrs.size(); i++) { + a.DeallocateRaw(existing_ptrs[i]); + } +} + +TEST(GPUBFCAllocatorTest, ExerciseCoalescing) { + GPUBFCAllocator a(0, 1 << 30); + + float* first_ptr = a.Allocate(1024); + a.Deallocate(first_ptr); + for (int i = 0; i < 1024; ++i) { + // Allocate several buffers of different sizes, and then clean them + // all up. We should be able to repeat this endlessly without + // causing fragmentation and growth. + float* t1 = a.Allocate(1024); + + int64* t2 = a.Allocate(1048576); + double* t3 = a.Allocate(2048); + float* t4 = a.Allocate(10485760); + + a.Deallocate(t1); + a.Deallocate(t2); + a.Deallocate(t3); + a.Deallocate(t4); + } + + // At the end, we should have coalesced all memory into one region + // starting at the beginning, so validate that allocating a pointer + // starts from this region. 
+ float* first_ptr_after = a.Allocate(1024); + EXPECT_EQ(first_ptr, first_ptr_after); + a.Deallocate(first_ptr_after); +} + +TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) { + GPUBFCAllocator a(0, 1 << 30); + float* ptr = a.Allocate(0); + EXPECT_EQ(nullptr, ptr); +} + +TEST(GPUBFCAllocatorTest, TracksSizes) { + GPUBFCAllocator a(0, 1 << 30); + EXPECT_EQ(true, a.TracksAllocationSizes()); +} + +TEST(GPUBFCAllocatorTest, AllocatedVsRequested) { + GPUBFCAllocator a(0, 1 << 30); + float* t1 = a.Allocate(1); + EXPECT_EQ(4, a.RequestedSize(t1)); + EXPECT_EQ(256, a.AllocatedSize(t1)); + a.Deallocate(t1); +} + +TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) { + // Configure a 1MiB byte limit + GPUBFCAllocator a(0, 1 << 20); + + float* first_ptr = a.Allocate(1 << 6); + float* second_ptr = a.Allocate(1 << 20); + + EXPECT_NE(nullptr, first_ptr); + EXPECT_EQ(nullptr, second_ptr); + a.Deallocate(first_ptr); +} + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc new file mode 100644 index 0000000000..5ec405cd80 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc @@ -0,0 +1,186 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h" + +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/stream_executor.h" + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +#define MASK_WORDS 2 +#define MASK_BYTES (MASK_WORDS * sizeof(int64)) + +namespace { + +static int64* NewMask(int64 word) { + int64* m = new int64[MASK_WORDS]; + for (int i = 0; i < MASK_WORDS; ++i) { + m[i] = word; + } + return m; +} + +static int64* before_mask = NewMask(0xabababababababab); +static int64* after_mask = NewMask(0xcdcdcdcdcdcdcdcd); + +bool CheckMask(perftools::gputools::StreamExecutor* exec, void* ptr, + int64* mask) { + gpu::DeviceMemory gpu_ptr{gpu::DeviceMemoryBase{ptr, MASK_BYTES}}; + int64 tmp[MASK_WORDS]; + + if (!exec->SynchronousMemcpy(&tmp, gpu_ptr, MASK_BYTES)) { + LOG(FATAL) << "Could not copy debug mask"; + } + + bool ok = true; + for (int i = 0; i < MASK_WORDS; ++i) { + ok &= (mask[i] == tmp[i]); + if (!ok) { + LOG(ERROR) << "i=" << i + << " mask=" << reinterpret_cast(mask[i]) + << " field=" << reinterpret_cast(tmp[i]); + } + } + + return ok; +} + +void InitMask(perftools::gputools::StreamExecutor* exec, void* ptr, + int64* mask) { + gpu::DeviceMemory gpu_ptr{gpu::DeviceMemoryBase{ptr, MASK_BYTES}}; + if (!exec->SynchronousMemcpy(&gpu_ptr, mask, MASK_BYTES)) { + LOG(FATAL) << "Could not copy debug mask"; + } +} + +} // namespace + +// ----------------------------------------------------------------------------- +// GPUDebugAllocator +// ----------------------------------------------------------------------------- +GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator, + int device_id) + : base_allocator_(allocator) { + stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); +} + +GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; } + +void* GPUDebugAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { + num_bytes += (2 * MASK_BYTES); + + void* allocated_ptr = base_allocator_->AllocateRaw(alignment, num_bytes); + + // Return the pointer after the header + void* rv = static_cast(allocated_ptr) + MASK_BYTES; + + // Write the header at allocated_ptr + 
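// Resulting layout of one debug allocation, for reference (MASK_BYTES is
// 2 * sizeof(int64) = 16 with the masks defined above):
//
//   allocated_ptr            rv = allocated_ptr + MASK_BYTES
//   |<- before_mask, 16B ->|<----- client num_bytes ----->|<- after_mask, 16B ->|
//
// CheckHeader()/CheckFooter() below re-read these two guard regions and
// compare them word-for-word against before_mask/after_mask to detect writes
// outside the client region.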
InitMask(stream_exec_, allocated_ptr, before_mask); + + // Write the footer at the end. + size_t req_size = base_allocator_->RequestedSize(allocated_ptr); + InitMask(stream_exec_, + static_cast(allocated_ptr) + req_size - MASK_BYTES, + after_mask); + return rv; +} +void GPUDebugAllocator::DeallocateRaw(void* ptr) { + CHECK(CheckHeader(ptr)) << "before_mask has been overwritten"; + CHECK(CheckFooter(ptr)) << "after_mask has been overwritten"; + + // Backtrack to the beginning of the header. + ptr = static_cast(static_cast(ptr) - MASK_BYTES); + // Deallocate the memory + base_allocator_->DeallocateRaw(ptr); +} + +void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) { + return base_allocator_->AddAllocVisitor(visitor); +} + +void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) { + return base_allocator_->AddFreeVisitor(visitor); +} + +bool GPUDebugAllocator::TracksAllocationSizes() { return true; } + +size_t GPUDebugAllocator::RequestedSize(void* ptr) { + auto req_size = + base_allocator_->RequestedSize(static_cast(ptr) - MASK_BYTES); + return req_size - 2 * MASK_BYTES; +} + +size_t GPUDebugAllocator::AllocatedSize(void* ptr) { + return base_allocator_->AllocatedSize(static_cast(ptr) - MASK_BYTES); +} + +bool GPUDebugAllocator::CheckHeader(void* ptr) { + return CheckMask(stream_exec_, static_cast(ptr) - MASK_BYTES, + before_mask); +} + +bool GPUDebugAllocator::CheckFooter(void* ptr) { + char* original_ptr = static_cast(ptr) - MASK_BYTES; + size_t req_size = base_allocator_->RequestedSize(original_ptr); + return CheckMask(stream_exec_, original_ptr + req_size - MASK_BYTES, + after_mask); +} + +// ----------------------------------------------------------------------------- +// GPUNanResetAllocator +// ----------------------------------------------------------------------------- +GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator, + int device_id) + : base_allocator_(allocator) { + stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); +} + +GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; } + +void* GPUNanResetAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { + void* allocated_ptr = base_allocator_->AllocateRaw(alignment, num_bytes); + + // Initialize the buffer to Nans + size_t req_size = base_allocator_->RequestedSize(allocated_ptr); + std::vector nans(req_size / sizeof(float), std::nanf("")); + gpu::DeviceMemory nan_ptr{ + gpu::DeviceMemoryBase{static_cast(allocated_ptr), req_size}}; + + if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) { + LOG(ERROR) << "Could not initialize to NaNs"; + } + + return allocated_ptr; +} +void GPUNanResetAllocator::DeallocateRaw(void* ptr) { + // Reset the buffer to Nans + size_t req_size = base_allocator_->RequestedSize(ptr); + std::vector nans(req_size / sizeof(float), std::nanf("")); + gpu::DeviceMemory nan_ptr{ + gpu::DeviceMemoryBase{static_cast(ptr), req_size}}; + if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) { + LOG(ERROR) << "Could not initialize to NaNs"; + } + + // Deallocate the memory + base_allocator_->DeallocateRaw(ptr); +} + +void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) { + return base_allocator_->AddAllocVisitor(visitor); +} + +void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) { + return base_allocator_->AddFreeVisitor(visitor); +} + +size_t GPUNanResetAllocator::RequestedSize(void* ptr) { + return base_allocator_->RequestedSize(ptr); +} + +size_t GPUNanResetAllocator::AllocatedSize(void* ptr) { 
+  return base_allocator_->AllocatedSize(ptr);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
new file mode 100644
index 0000000000..c9b564ffc4
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -0,0 +1,68 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// An allocator that wraps a GPU allocator and adds debugging
+// functionality that verifies that users do not write outside their
+// allocated memory.
+class GPUDebugAllocator : public VisitableAllocator {
+ public:
+  explicit GPUDebugAllocator(VisitableAllocator* allocator, int device_id);
+  ~GPUDebugAllocator() override;
+  string Name() override { return "gpu_debug"; }
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+  void DeallocateRaw(void* ptr) override;
+  void AddAllocVisitor(Visitor visitor) override;
+  void AddFreeVisitor(Visitor visitor) override;
+  bool TracksAllocationSizes() override;
+  size_t RequestedSize(void* ptr) override;
+  size_t AllocatedSize(void* ptr) override;
+
+  // For testing.
+  bool CheckHeader(void* ptr);
+  bool CheckFooter(void* ptr);
+
+ private:
+  VisitableAllocator* base_allocator_ = nullptr;  // owned
+
+  perftools::gputools::StreamExecutor* stream_exec_;  // Not owned.
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GPUDebugAllocator);
+};
+
+// An allocator that wraps a GPU allocator and resets the memory on
+// allocation and free to 'NaN', helping to identify cases where the
+// user forgets to initialize the memory.
+class GPUNanResetAllocator : public VisitableAllocator {
+ public:
+  explicit GPUNanResetAllocator(VisitableAllocator* allocator, int device_id);
+  ~GPUNanResetAllocator() override;
+  string Name() override { return "gpu_nan_reset"; }
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+  void DeallocateRaw(void* ptr) override;
+  void AddAllocVisitor(Visitor visitor) override;
+  void AddFreeVisitor(Visitor visitor) override;
+  size_t RequestedSize(void* ptr) override;
+  size_t AllocatedSize(void* ptr) override;
+
+ private:
+  VisitableAllocator* base_allocator_ = nullptr;  // owned
+
+  perftools::gputools::StreamExecutor* stream_exec_;  // Not owned.
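// Typical composition, as exercised in gpu_debug_allocator_test.cc below
// (which notes that the NaN-reset wrapper must be the outer-most allocator):
//
//   GPUNanResetAllocator a(
//       new GPUDebugAllocator(new GPUBFCAllocator(device_id, 1 << 30),
//                             device_id),
//       device_id);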
+ + TF_DISALLOW_COPY_AND_ASSIGN(GPUNanResetAllocator); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc new file mode 100644 index 0000000000..5f63906576 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc @@ -0,0 +1,207 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h" + +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/stream_executor.h" +#include + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +TEST(GPUDebugAllocatorTest, OverwriteDetection_None) { + const int device_id = 0; + GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30), device_id); + auto stream_exec = + GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); + + for (int s : {8}) { + std::vector cpu_array(s); + memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64)); + int64* gpu_array = a.Allocate(cpu_array.size()); + gpu::DeviceMemory gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}}; + ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0], + s * sizeof(int64))); + EXPECT_TRUE(a.CheckHeader(gpu_array)); + EXPECT_TRUE(a.CheckFooter(gpu_array)); + + // Confirm no error on free. + a.DeallocateRaw(gpu_array); + } +} + +TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) { + for (int s : {8, 211}) { + EXPECT_DEATH( + { + const int device_id = 0; + GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30), + device_id); + auto stream_exec = + GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); + + std::vector cpu_array(s); + memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64)); + int64* gpu_array = a.Allocate(cpu_array.size()); + + gpu::DeviceMemory gpu_array_ptr{ + gpu::DeviceMemoryBase{gpu_array}}; + ASSERT_TRUE(stream_exec->SynchronousMemcpy( + &gpu_array_ptr, &cpu_array[0], cpu_array.size() * sizeof(int64))); + + gpu::DeviceMemory gpu_hdr_ptr{ + gpu::DeviceMemoryBase{gpu_array - 1}}; + // Clobber first word of the header. + float pi = 3.1417; + ASSERT_TRUE( + stream_exec->SynchronousMemcpy(&gpu_hdr_ptr, &pi, sizeof(float))); + + // Expect error on free. + a.Deallocate(gpu_array); + }, + ""); + } +} + +TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) { + for (int s : {8, 22}) { + EXPECT_DEATH( + { + const int device_id = 0; + GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30), + device_id); + auto stream_exec = + GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); + + std::vector cpu_array(s); + memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64)); + int64* gpu_array = a.Allocate(cpu_array.size()); + + gpu::DeviceMemory gpu_array_ptr{ + gpu::DeviceMemoryBase{gpu_array}}; + ASSERT_TRUE(stream_exec->SynchronousMemcpy( + &gpu_array_ptr, &cpu_array[0], cpu_array.size() * sizeof(int64))); + + // Clobber word of the footer. + gpu::DeviceMemory gpu_ftr_ptr{ + gpu::DeviceMemoryBase{gpu_array + s}}; + float pi = 3.1417; + ASSERT_TRUE( + stream_exec->SynchronousMemcpy(&gpu_ftr_ptr, &pi, sizeof(float))); + + // Expect error on free. 
+ a.Deallocate(gpu_array); + }, + ""); + } +} + +TEST(GPUDebugAllocatorTest, ResetToNan) { + const int device_id = 0; + GPUNanResetAllocator a(new GPUBFCAllocator(device_id, 1 << 30), device_id); + auto stream_exec = + GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); + + std::vector cpu_array(1024); + std::vector cpu_array_result(1024); + + // Allocate 1024 floats + float* gpu_array = a.Allocate(cpu_array.size()); + gpu::DeviceMemory gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}}; + ASSERT_TRUE(stream_exec->SynchronousMemcpy(&cpu_array[0], gpu_array_ptr, + cpu_array.size() * sizeof(float))); + for (float f : cpu_array) { + ASSERT_FALSE(std::isfinite(f)); + } + + // Set one of the fields to 1.0. + cpu_array[0] = 1.0; + ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0], + cpu_array.size() * sizeof(float))); + // Copy the data back and verify. + ASSERT_TRUE( + stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr, + cpu_array_result.size() * sizeof(float))); + ASSERT_EQ(1.0, cpu_array_result[0]); + + // Free the array + a.Deallocate(gpu_array); + + // All values should be reset to nan. + ASSERT_TRUE( + stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr, + cpu_array_result.size() * sizeof(float))); + for (float f : cpu_array_result) { + ASSERT_FALSE(std::isfinite(f)); + } +} + +TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) { + const int device_id = 0; + // NaN reset must be the outer-most allocator. + GPUNanResetAllocator a( + new GPUDebugAllocator(new GPUBFCAllocator(device_id, 1 << 30), device_id), + device_id); + auto stream_exec = + GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); + + std::vector cpu_array(1024); + std::vector cpu_array_result(1024); + + // Allocate 1024 floats + float* gpu_array = a.Allocate(cpu_array.size()); + gpu::DeviceMemory gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}}; + ASSERT_TRUE(stream_exec->SynchronousMemcpy(&cpu_array[0], gpu_array_ptr, + cpu_array.size() * sizeof(float))); + for (float f : cpu_array) { + ASSERT_FALSE(std::isfinite(f)); + } + + // Set one of the fields to 1.0. + cpu_array[0] = 1.0; + ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0], + cpu_array.size() * sizeof(float))); + // Copy the data back and verify. + ASSERT_TRUE( + stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr, + cpu_array_result.size() * sizeof(float))); + ASSERT_EQ(1.0, cpu_array_result[0]); + + // Free the array + a.Deallocate(gpu_array); + + // All values should be reset to nan. 
+ ASSERT_TRUE( + stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr, + cpu_array_result.size() * sizeof(float))); + for (float f : cpu_array_result) { + ASSERT_FALSE(std::isfinite(f)); + } +} + +TEST(GPUDebugAllocatorTest, TracksSizes) { + GPUDebugAllocator a(new GPUBFCAllocator(0, 1 << 30), 0); + EXPECT_EQ(true, a.TracksAllocationSizes()); +} + +TEST(GPUDebugAllocatorTest, AllocatedVsRequested) { + GPUNanResetAllocator a( + new GPUDebugAllocator(new GPUBFCAllocator(0, 1 << 30), 0), 0); + float* t1 = a.Allocate(1); + EXPECT_EQ(4, a.RequestedSize(t1)); + EXPECT_EQ(256, a.AllocatedSize(t1)); + a.Deallocate(t1); +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc new file mode 100644 index 0000000000..26d34645f1 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -0,0 +1,651 @@ +// TODO(opensource): Use a more generic sounding preprocessor name than +// GOOGLE_CUDA +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/common_runtime/gpu/gpu_device.h" + +#include +#include + +//#include "base/commandlineflags.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/stream_executor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h" +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" +#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/device_name_utils.h" + +#if defined(PLATFORM_GOOGLE) +DEFINE_bool(brain_gpu_sync_every_op, false, + "If true, call GPUUtil::Sync() between every dispatched opkernel."); + +DEFINE_int32(brain_gpu_max_streams, 1, + "Max number of GPU streams to use for computation."); +#else +// TODO(opensource): These should be made options in some options struct, +// rather than flags. +bool FLAGS_brain_gpu_sync_every_op = false; +tensorflow::int32 FLAGS_brain_gpu_max_streams = 1; +#endif + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +// Eigen Ops directly allocate memory only for temporary buffers used +// during OpKernel::Compute(). The recommended way of allocating such +// memory is via OpKernelContext::allocate_temp(). However, Eigen Ops +// don't have access to OpKernelContext, instead they get access to +// memory directly through the device allocator. 
As an Open Source +// project, Eigen assumes allocator semantics similar to those of the +// CUDA memory allocator, and may not work correctly due to race +// conditions if used with some other allocator. For safety, we need +// to delay deallocation calls out of Eigen until all events on the +// corresponding stream have completed. The following two classes +// serve this purpose in two different compilation environments. + +#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__) +class EigenAllocator : public ::Eigen::Allocator { + public: + explicit EigenAllocator(gpu::Stream* stream, ::tensorflow::Allocator* alloc, + EventMgr* em) + : stream_(stream), allocator_(alloc), em_(em) {} + + void* allocate(size_t num_bytes) const override { + void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes); + // Eigen doesn't typically check the return pointer from allocate, + // so we do it here and die with a more helpful error message. + if (ret == nullptr) { + LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating " + << num_bytes << ". See error logs for more detailed info."; + } + return ret; + } + + void deallocate(void* buffer) const override { + em_->ThenDeleteBuffer(stream_, {allocator_, buffer}); + } + + private: + gpu::Stream* stream_; // Not owned. + ::tensorflow::Allocator* allocator_; // Not owned. + ::tensorflow::EventMgr* em_; // Not owned. + + TF_DISALLOW_COPY_AND_ASSIGN(EigenAllocator); +}; + +#else +class EigenCudaStreamDevice : public ::Eigen::StreamInterface { + public: + EigenCudaStreamDevice(const cudaStream_t* cuda_stream, int gpu_id, + ::tensorflow::Allocator* alloc) + : stream_(cuda_stream), allocator_(alloc) { + Eigen::initializeDeviceProp(); + device_prop_ = &Eigen::m_deviceProperties[gpu_id]; + } + + const cudaStream_t& stream() const override { return *stream_; } + const cudaDeviceProp& deviceProperties() const override { + return *device_prop_; + } + + void* allocate(size_t num_bytes) const override { + void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes); + if (ret == nullptr) { + LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating " + << num_bytes << ". See error logs for more detailed info."; + } + + return ret; + } + void deallocate(void* buffer) const override { + AsyncFreeData* afData = new AsyncFreeData(allocator_, buffer); + cudaError_t err = cudaStreamAddCallback(*stream_, asyncFree, afData, 0); + CHECK_EQ(err, cudaSuccess); + } + + private: + struct AsyncFreeData { + AsyncFreeData(::tensorflow::Allocator* a, void* p) + : allocator_(a), address_(p) {} + ::tensorflow::Allocator* allocator_; + void* address_; + }; + + static void CUDART_CB asyncFree(cudaStream_t stream, cudaError_t status, + void* userData) { + AsyncFreeData* data = static_cast(userData); + data->allocator_->DeallocateRaw(data->address_); + delete data; + } + + const cudaStream_t* stream_; // Not owned. + const cudaDeviceProp* device_prop_; // Not owned. + ::tensorflow::Allocator* allocator_; // Not owned. 
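// Note on deallocate() above: cudaStreamAddCallback only fires the callback
// after all work previously enqueued on *stream_ has completed, so the
// buffer handed to asyncFree cannot be released while an Eigen kernel
// launched earlier on the same stream may still be reading or writing it.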
+ + TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice); +}; + +#endif + +BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name, + Bytes memory_limit, BusAdjacency bus_adjacency, + int gpu_id, const string& physical_device_desc, + Allocator* gpu_allocator, Allocator* cpu_allocator) + : LocalDevice(options, Device::BuildDeviceAttributes( + name, DEVICE_GPU, memory_limit, bus_adjacency, + physical_device_desc), + gpu_allocator), + gpu_allocator_(gpu_allocator), + cpu_allocator_(cpu_allocator), + gpu_id_(gpu_id) { + gpu::StreamExecutor* executor = + GPUMachineManager()->ExecutorForDevice(gpu_id_).ValueOrDie(); + if (!executor) { + LOG(ERROR) << "Failed to get StreamExecutor for device " << gpu_id_; + return; + } + em_.reset(new EventMgr(executor)); + + if (FLAGS_brain_gpu_max_streams < 1) { + LOG(FATAL) << "Invalid value for brain_gpu_max_streams."; + } + + // Create the specified number of GPU streams + for (int i = 0; i < FLAGS_brain_gpu_max_streams; i++) { + auto stream = new gpu::Stream(executor); + stream->Init(); + VLOG(2) << "Created stream[" << i << "] = " << stream; + streams_.push_back(stream); + device_contexts_.push_back(new GPUDeviceContext(i, stream)); + } + gpu_device_info_ = new GpuDeviceInfo; + gpu_device_info_->stream = streams_[0]; + gpu_device_info_->default_context = device_contexts_[0]; + gpu_device_info_->event_mgr = em_.get(); + set_tensorflow_gpu_device_info(gpu_device_info_); +} + +BaseGPUDevice::~BaseGPUDevice() { + delete gpu_device_info_; + for (auto ctx : device_contexts_) ctx->Unref(); + gtl::STLDeleteElements(&streams_); +} + +Status BaseGPUDevice::FillContextMap(const Graph* graph, + DeviceContextMap* device_context_map) { + VLOG(2) << "FillContextMap"; + + const auto num_streams = streams_.size(); + // Special case for single stream. + if (num_streams == 1) { + return Status::OK(); + } + const int64 before = Env::Default()->NowMicros(); + gpu_stream_util::AssignStreamsOpts opts; + opts.max_streams = num_streams; + std::unordered_map node_to_stream_id; + TF_RETURN_IF_ERROR( + gpu_stream_util::AssignStreams(graph, opts, &node_to_stream_id)); + int64 elapsed = Env::Default()->NowMicros() - before; + VLOG(3) << "AssignStreams took " << elapsed << "us"; + + // Fill in the context map. It is OK for this map to contain + // duplicate DeviceContexts so long as we increment the refcount. + for (Node* n : graph->nodes()) { + auto mapped_stream = node_to_stream_id[n->id()]; + CHECK_LE(mapped_stream, num_streams); + auto ctx = device_contexts_[mapped_stream]; + VLOG(3) << "Assigned stream " << node_to_stream_id[n->id()] + << " ==> stream[" << ctx->stream_id() << "] for node id " << n->id() + << " " << n->type_string() << " " << n->name(); + ctx->Ref(); + device_context_map->insert(std::make_pair(n->id(), ctx)); + } + + return Status::OK(); +} + +void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { + // ScopedActivity is cheap when tracing is not active, but we + // can avoid computing the Hash64. + // TODO(pbar) This would no longer be needed if Ops have a unique id. + const uint64 id = port::Tracing::IsActive() ? 
Hash64(op_kernel->name()) : 0; + port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute, + id); + + GPUDeviceContext* gpu_device_context = device_contexts_[0]; + if (context->op_device_context() != nullptr) { + gpu_device_context = + static_cast(context->op_device_context()); + } + gpu::Stream* stream = gpu_device_context->stream(); + const auto stream_id = gpu_device_context->stream_id(); + + VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op " + << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream[" + << stream_id << "]"; + + // NOTE(tucker): We need to discriminate between Eigen GPU + // operations and all others. If an operation is Eigen + // implemented (or otherwise tries to launch a cuda kernel + // directly), we need to establish a stacked-scoped environment + // that directs it to execute on the proper device. Otherwise we + // expect the Op to use StreamExecutor directly and correctly. The + // way we make this discrimination is quite hacky: At the moment + // the only non-Eigen GPU Op is the recv-op, which is known to be + // asynchronous. + if (op_kernel->type_string() == "_Recv") { + context->SetStatus(errors::Internal( + "Invalid synchronous 'Compute' on GPU for '_Recv' op")); + } else { + const string label = + strings::StrCat(op_kernel->name(), ":", op_kernel->type_string()); + port::Tracing::ScopedAnnotation annotation(label); + + const auto num_streams = streams_.size(); + if (num_streams > 1) { + // If this op's device context is different from the other contexts, + // we must wait on the stream. + for (int i = 0; i < context->num_inputs(); ++i) { + const GPUDeviceContext* idc = + static_cast(context->input_device_context(i)); + OP_REQUIRES(context, idc != nullptr, + errors::Internal("Input device context ", i, + " was not set properly.")); + if (VLOG_IS_ON(2)) { + const void* base; + size_t len; + if (context->has_input(i)) { + if (IsRefType(context->input_dtype(i))) { + Tensor tensor = context->mutable_input(i, false); + base = DMAHelper::base(&tensor); + len = tensor.TotalBytes(); + } else { + const Tensor& tensor = context->input(i); + base = DMAHelper::base(&tensor); + len = tensor.TotalBytes(); + } + VLOG(2) << "Input " << i << " " << base << " " << len; + VLOG(2) << " stream[" << stream_id << "].ThenWaitFor(stream[" + << idc->stream_id() << "])" + << ((idc->stream() == stream) ? " not needed" : ""); + } + } + if (idc->stream() != stream) stream->ThenWaitFor(idc->stream()); + } + } + gpu::cuda::ScopedActivateExecutorContext scoped_activation{ + stream->parent(), gpu::cuda::MultiOpActivation::kYes}; + // Keep a copy of the inputs before Compute runs, in case they get + // deleted. TODO(misard) this will be fixed when the tracking is + // done right. + std::vector* tensor_refs = nullptr; + if (!FLAGS_brain_gpu_sync_every_op) { + tensor_refs = new std::vector; + tensor_refs->reserve(context->num_inputs() + context->num_outputs()); + for (int ii = 0; ii < context->num_inputs(); ++ii) { + if (context->has_input(ii)) { + if (IsRefType(context->input_dtype(ii))) { + Tensor in = context->mutable_input(ii, false); + tensor_refs->push_back(in); + } else { + const Tensor& in = context->input(ii); + tensor_refs->push_back(in); + } + } + } + } + op_kernel->Compute(context); + if (context->status().ok()) { + if (FLAGS_brain_gpu_sync_every_op) { + // Note: GPUUtil::Sync() only syncs the default stream. + // We need to either sync the stream used by this op, or + // all streams. 
Given that this flag is typically used for + // debugging it makes more sense to sync all GPU activity. + context->SetStatus(GPUUtil::SyncAll(this)); + } else { + // The GPU kernel has been queued, but may not complete for some + // time. As soon as this function completes, the caller will + // discard its refs on the inputs, outputs and any scratch + // tensors it created. Create additional refs here that will be + // held until the kernel completes. + for (int ii = 0; ii < context->num_temps(); ++ii) { + Tensor* temp = context->temp(ii); + VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp); + tensor_refs->push_back(*temp); + } + for (int ii = 0; ii < context->num_outputs(); ++ii) { + Tensor* temp = context->mutable_output(ii); + if (nullptr != temp) { + tensor_refs->push_back(*temp); + } + } + em_->ThenDeleteTensors(stream, tensor_refs); + } + } else { + if (!FLAGS_brain_gpu_sync_every_op) { + delete tensor_refs; + } + } + } +} + +Status BaseGPUDevice::Sync() { return GPUUtil::Sync(this); } + +void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel, + OpKernelContext* context, + AsyncOpKernel::DoneCallback done) { + GPUDeviceContext* gpu_device_context = device_contexts_[0]; + if (context->op_device_context() != nullptr) { + gpu_device_context = + static_cast(context->op_device_context()); + } + const auto stream_id = gpu_device_context->stream_id(); + + VLOG(1) << "GpuDevice::ComputeAsync " << op_kernel->name() << " op " + << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream[" + << stream_id << "]"; + + port::Tracing::TraceMe activity( + strings::StrCat(op_kernel->name(), ":", op_kernel->type_string())); + op_kernel->ComputeAsync(context, done); +} + +Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) { + AllocatorAttributes attr; + attr.set_on_host(true); + attr.set_gpu_compatible(true); + Allocator* host_alloc = GetAllocator(attr); + Tensor parsed(tensor_proto.dtype()); + if (!parsed.FromProto(host_alloc, tensor_proto)) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + tensor_proto.DebugString()); + } + Status status; + if (alloc_attrs.on_host()) { + *tensor = parsed; + } else { + if (!DMAHelper::CanUseDMA(&parsed)) { + return errors::Internal("GPU copy from non-DMA ", + DataTypeString(parsed.dtype()), " tensor"); + } + Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); + port::Tracing::ScopedAnnotation annotation("MakeTensorFromProto"); + Notification n; + device_contexts_[0]->CopyCPUTensorToDevice(&parsed, this, ©, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + *tensor = copy; + } + return status; +} + +namespace { +#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__) +class ConcretePerOpGpuDevice : public PerOpGpuDevice { + public: + explicit ConcretePerOpGpuDevice(gpu::Stream* stream, + EigenAllocator* allocator) + : device_(stream, allocator), allocator_(allocator) {} + ~ConcretePerOpGpuDevice() { delete allocator_; } + + const Eigen::GpuDevice& device() const override { return device_; } + + private: + Eigen::GpuDevice device_; + EigenAllocator* allocator_; +}; +#else +class ConcretePerOpGpuDevice : public PerOpGpuDevice { + public: + explicit ConcretePerOpGpuDevice(EigenCudaStreamDevice* stream_device) + : device_(stream_device), stream_device_(stream_device) {} + ~ConcretePerOpGpuDevice() { delete stream_device_; } + + const Eigen::GpuDevice& device() const override { return 
device_; } + + private: + Eigen::GpuDevice device_; + EigenCudaStreamDevice* stream_device_; +}; +#endif +} // namespace + +const PerOpGpuDevice* BaseGPUDevice::NewDevice(int stream_id, + Allocator* allocator) { +#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__) + auto ea = new EigenAllocator(streams_[stream_id], allocator, em_.get()); + return new ConcretePerOpGpuDevice(streams_[stream_id], ea); +#else + const cudaStream_t* cuda_stream = reinterpret_cast( + streams_[stream_id]->implementation()->CudaStreamMemberHack()); + auto es = new EigenCudaStreamDevice(cuda_stream, gpu_id_, allocator); + return new ConcretePerOpGpuDevice(es); +#endif +} + +const PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice(DeviceContext* dc, + Allocator* allocator) { + if (dc) { + const GPUDeviceContext* gpu_dc = static_cast(dc); + const int stream_id = gpu_dc->stream_id(); + VLOG(1) << " eigen_gpu_device(" << dc << ") => stream[" << stream_id + << "]"; + CHECK_LT(stream_id, streams_.size()); + return NewDevice(stream_id, allocator); + } else { + return NewDevice(0, allocator); + } +} + +void BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, + const string& name_prefix, + std::vector* devices) { + int n = INT_MAX; + auto iter = options.config.device_count().find("GPU"); + if (iter != options.config.device_count().end()) { + n = iter->second; + } + std::vector valid_gpu_ids; + GetValidDeviceIds(&valid_gpu_ids); + if (static_cast(n) > valid_gpu_ids.size()) { + n = valid_gpu_ids.size(); + } + for (int i = 0; i < n; i++) { + devices->push_back(CreateGPUDevice( + options, strings::StrCat(name_prefix, "/gpu:", i), valid_gpu_ids[i])); + } +} + +namespace { +int64 MinSystemMemory(int64 available_memory) { + // We use the following heuristic for now: + // + // If the available_memory is < 2GiB, we allocate 200MiB to system memory. + // Otherwise, allocate 300MiB to system memory. + // + // In the future we could be more sophisticated by using a table of + // devices. + if (available_memory < (1LL << 31)) { + // 200MiB + return 209715200LL; + } else { + // max(300 MiB, 0.95 * available_memory) + return std::max(314572800LL, static_cast(available_memory * 0.05)); + } +} +} // namespace + +static string GetShortDeviceDescription(int device_id, + const gpu::DeviceDescription& desc) { + return strings::StrCat("device: ", device_id, ", name: ", desc.name(), + ", pci bus id: ", desc.pci_bus_id()); +} + +LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice( + const SessionOptions& options, const string& name, int gpu_id) { + CHECK_GE(gpu_id, 0); + + // Look up the device, to see its attributes. + gpu::Platform* gpu_platform = GPUMachineManager(); + CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount()); + gpu::StreamExecutor* se = + gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie(); + const gpu::DeviceDescription& desc = se->GetDeviceDescription(); + + int64 total_memory, available_memory; + CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory)); + + int64 allocated_memory = available_memory; + double config_memory_fraction = + options.config.gpu_options().per_process_gpu_memory_fraction(); + if (config_memory_fraction == 0) { + const int64 min_system_memory = MinSystemMemory(available_memory); + if (min_system_memory < allocated_memory) { + allocated_memory -= min_system_memory; + } + } else { + allocated_memory *= config_memory_fraction; + } + + Bytes allocated_bytes = static_cast(allocated_memory); + + // Get GPU BusAdjacency from its reported NUMA affinity. 
+ // Because GPUs are virtualized in some environments, we can't just + // use the GPU id. + BusAdjacency bus_adjacency = BUS_ANY; + switch (desc.numa_node()) { + case 0: + bus_adjacency = BUS_0; + break; + case 1: + bus_adjacency = BUS_1; + break; + default: + bus_adjacency = BUS_ANY; + } + VLOG(1) << "GPUDevice id " << gpu_id << " on bus " << bus_adjacency + << " numa: " << desc.numa_node() << " pci: " << desc.pci_bus_id(); + + ProcessState* process_state = ProcessState::singleton(); + return CreateGPUDevice( + options, name, allocated_bytes, bus_adjacency, gpu_id, + GetShortDeviceDescription(gpu_id, desc), + process_state->GetGPUAllocator(gpu_id, allocated_memory), + process_state->GetCPUAllocator(desc.numa_node())); +} + +static int GetMinGPUMultiprocessorCount() { + static const int kDefaultMinGPUMultiprocessorCount = 8; + + const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT"); + + if (tf_min_gpu_core_count == nullptr || + strcmp(tf_min_gpu_core_count, "") == 0) { + return kDefaultMinGPUMultiprocessorCount; + } + + int min_gpu_core_count = -1; + if (strings::safe_strto32(tf_min_gpu_core_count, &min_gpu_core_count)) { + if (min_gpu_core_count >= 0) { + return min_gpu_core_count; + } + } + + LOG(ERROR) << "Invalid minimum GPU multiprocessor count: [" + << tf_min_gpu_core_count << "]. " + << "Using the default value: " + << kDefaultMinGPUMultiprocessorCount; + return kDefaultMinGPUMultiprocessorCount; +} + +void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector* ids) { + auto gpu_manager = GPUMachineManager(); + int min_gpu_core_count = GetMinGPUMultiprocessorCount(); + if (gpu_manager) { + auto visible_device_count = gpu_manager->VisibleDeviceCount(); + for (int i = 0; i < gpu_manager->VisibleDeviceCount(); ++i) { + auto exec_status = gpu_manager->ExecutorForDevice(i); + if (!exec_status.ok()) { + continue; + } + gpu::StreamExecutor* se = exec_status.ValueOrDie(); + const gpu::DeviceDescription& desc = se->GetDeviceDescription(); + int major, minor; + if (!desc.cuda_compute_capability(&major, &minor)) { + continue; + } + // Only consider GPUs with compute capability >= 3.5 (Kepler or + // higher) + if (major < 3 || (major == 3 && minor < 5)) { + LOG(INFO) << "Ignoring gpu device " + << "(" << GetShortDeviceDescription(i, desc) << ") " + << "with Cuda compute capability " << major << "." << minor + << ". The minimum required Cuda capability is 3.5."; + continue; + } + + // TensorFlow currently places computation on devices assuming + // they have similar capability. + // + // If there are multiple GPUs available on the machine, only + // consider GPUs with 8 or more multiprocessors. + // + // TODO(vrv): In the medium term: we should only filter out GPUs + // that are slow relative to the fastest GPU. In the long term, + // TensorFlow should support automatic placement based on + // capability. + if (visible_device_count > 1) { + if (desc.core_count() < min_gpu_core_count) { + LOG(INFO) << "Ignoring gpu device " + << "(" << GetShortDeviceDescription(i, desc) << ") " + << "with Cuda multiprocessor count: " << desc.core_count() + << ". The minimum required count is " << min_gpu_core_count + << ". 
You can adjust this requirement with the env var " + "TF_MIN_GPU_MULTIPROCESSOR_COUNT."; + continue; + } + } + + int new_id = ids->size(); + ids->push_back(i); + + LOG(INFO) << "Creating TensorFlow device (/gpu:" << new_id << ") -> " + << "(" << GetShortDeviceDescription(i, desc) << ")"; + } + } +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h new file mode 100644 index 0000000000..a415224d95 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_device.h @@ -0,0 +1,94 @@ +#if !GOOGLE_CUDA +#error This file must only be included when building with Cuda support +#endif + +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/stream_executor/stream.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +class EigenAllocator; + +class BaseGPUDevice : public LocalDevice { + public: + BaseGPUDevice(const SessionOptions& options, const string& name, + Bytes memory_limit, BusAdjacency bus_adjacency, int gpu_id, + const string& physical_device_desc, Allocator* gpu_allocator, + Allocator* cpu_allocator); + + ~BaseGPUDevice() override; + + // GPU devices require the Op Compute method to save a reference to + // any temporary tensors that are allocated until the Op execution + // completes. + bool SaveTemporaryTensors() const override { return true; } + + Status FillContextMap(const Graph* graph, + DeviceContextMap* device_context_map); + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + + Status Sync() override; + + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + + Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + + // The caller owns the returned device. 
+ const PerOpGpuDevice* MakeGpuDevice(DeviceContext* dc, + Allocator* allocator) override; + + protected: + Allocator* gpu_allocator_; // not owned + Allocator* cpu_allocator_; // not owned + + private: + std::vector streams_; + std::vector device_contexts_; + GpuDeviceInfo* gpu_device_info_ = nullptr; + mutex trace_mu_; + int gpu_id_ = -1; + std::unique_ptr em_; + + const PerOpGpuDevice* NewDevice(int stream_id, Allocator* allocator); +}; + +class BaseGPUDeviceFactory : public DeviceFactory { + public: + void CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector* devices) override; + + private: + LocalDevice* CreateGPUDevice(const SessionOptions& options, + const string& name, int gpu_id); + + virtual LocalDevice* CreateGPUDevice(const SessionOptions& options, + const string& name, Bytes memory_limit, + BusAdjacency bus_adjacency, int gpu_id, + const string& physical_device_desc, + Allocator* gpu_allocator, + Allocator* cpu_allocator) = 0; + + void GetValidDeviceIds(std::vector* ids); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc new file mode 100644 index 0000000000..240ac47499 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc @@ -0,0 +1,52 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/common_runtime/gpu/gpu_device.h" +#include "tensorflow/core/common_runtime/gpu/process_state.h" + +namespace tensorflow { + +void RequireGPUDevice() {} + +class GPUDevice : public BaseGPUDevice { + public: + GPUDevice(const SessionOptions& options, const string& name, + Bytes memory_limit, BusAdjacency bus_adjacency, int gpu_id, + const string& physical_device_desc, Allocator* gpu_allocator, + Allocator* cpu_allocator) + : BaseGPUDevice(options, name, memory_limit, bus_adjacency, gpu_id, + physical_device_desc, gpu_allocator, cpu_allocator) {} + + Allocator* GetAllocator(AllocatorAttributes attr) override { + if (attr.on_host()) { + ProcessState* ps = ProcessState::singleton(); + if (attr.gpu_compatible()) { + return ps->GetCUDAHostAllocator(0); + } else { + return cpu_allocator_; + } + } else { + return gpu_allocator_; + } + } +}; + +class GPUDeviceFactory : public BaseGPUDeviceFactory { + private: + LocalDevice* CreateGPUDevice(const SessionOptions& options, + const string& name, Bytes memory_limit, + BusAdjacency bus_adjacency, int gpu_id, + const string& physical_device_desc, + Allocator* gpu_allocator, + Allocator* cpu_allocator) override { + return new GPUDevice(options, name, memory_limit, bus_adjacency, gpu_id, + physical_device_desc, gpu_allocator, cpu_allocator); + } +}; + +REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc new file mode 100644 index 0000000000..29d6281733 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc @@ -0,0 +1,132 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" + +#include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/stream.h" + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +EventMgr::EventMgr(gpu::StreamExecutor* se) + : exec_(se), + // threadpool_ has 1 thread for the polling loop, and one to execute + // event callback functions. 
Maybe we should have more? + threadpool_(Env::Default(), "GPU_Event_Manager", 2) { + threadpool_.Schedule([this]() { PollLoop(); }); +} + +EventMgr::~EventMgr() { + stop_polling_.Notify(); + // Shut down the backup polling loop. + polling_stopped_.WaitForNotification(); + + // Events are owned by this object. + for (auto& e : free_events_) { + delete e; + } + while (!used_events_.empty()) { + delete used_events_[0].event; + delete used_events_[0].mem; + if (used_events_[0].bufrec.buf) { + used_events_[0].bufrec.alloc->DeallocateRaw(used_events_[0].bufrec.buf); + } + if (used_events_[0].func != nullptr) + threadpool_.Schedule(used_events_[0].func); + used_events_.pop_front(); + } +} + +// This polling loop runs at a relatively low frequency. Most calls to +// PollEvents() should come directly from Compute() via +// ThenDeleteTensors(). This function's purpose is to ensure that +// even if no more GPU operations are being requested, we still +// eventually clear the queue. It seems to prevent some tensorflow +// programs from stalling for reasons not yet understood. +void EventMgr::PollLoop() { + while (!stop_polling_.HasBeenNotified()) { + Env::Default()->SleepForMicroseconds(1 * 1000); + { + mutex_lock l(mu_); + PollEvents(true); + } + } + polling_stopped_.Notify(); +} + +void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) { + VLOG(2) << "QueueInUse free_events_ " << free_events_.size() + << " used_events_ " << used_events_.size(); + // Events are created on demand, and repeatedly reused. There is no + // limit placed here on the number of allocated Events. + if (free_events_.empty()) { + free_events_.push_back(new gpu::Event(exec_)); + free_events_.back()->Init(); + } + gpu::Event* e = free_events_.back(); + free_events_.pop_back(); + stream->ThenRecordEvent(e); + iu.event = e; + used_events_.push_back(iu); +} + +// This function must be called periodically to check whether pending +// events have recorded, and then retire them. Initial observations +// suggest that typical behavior in a TensorFlow program is to have +// 0-3 events pending most of the time, but there are occasionally +// spikes of up to several hundred outstanding. +// +// NOTE: If all events are on the same stream, no later event will +// complete before an earlier event, except possibly if the earlier +// event transitions to an error state, so there's no advantage in +// looking past the first kPending event. However, if we're using +// multiple streams there may be some gain in looking deeper. +// As a compromise, PollEvent() calls that are triggered by the queueing +// of a single event never look past the first kPending event. Calls +// coming from the dedicated polling thread always sweep the full queue. +// +// Note that allowing the queue to grow very long could cause overall +// GPU memory use to spike needlessly. An alternative strategy would +// be to throttle new Op execution until the pending event queue +// clears. +void EventMgr::PollEvents(bool is_dedicated_poller) { + VLOG(2) << "PollEvents free_events_ " << free_events_.size() + << " used_events_ " << used_events_.size(); + // Sweep the remaining events in order. If this is the dedicated + // polling thread, check the entire set. Otherwise, just sweep up to + // the first non-complete record that is still pending. 
+ for (auto& iu : used_events_) { + if (iu.event == nullptr) continue; + gpu::Event::Status s = iu.event->PollForStatus(); + switch (s) { + case gpu::Event::Status::kUnknown: + case gpu::Event::Status::kError: + // We don't expect to see these. Someday maybe propagate + // a Status error, but for now fail hard. + LOG(FATAL) << "Unexpected Event status: " << static_cast(s); + break; + case gpu::Event::Status::kPending: + if (!is_dedicated_poller) return; // quit processing queue + break; + case gpu::Event::Status::kComplete: + delete iu.mem; + if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf); + // The function must be called in another thread, outside of + // the mutex held here. + if (iu.func != nullptr) threadpool_.Schedule(iu.func); + free_events_.push_back(iu.event); + // Mark this InUse record as completed. + iu.event = nullptr; + } + } + // Then clear any completed InUse records from the front of the queue. + while (!used_events_.empty()) { + InUse& iu = used_events_.front(); + if (iu.event == nullptr) { + used_events_.pop_front(); + } else { + break; + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h new file mode 100644 index 0000000000..f9436566d4 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h @@ -0,0 +1,118 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_ + +#include +#include +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/tensor.h" + +namespace perftools { +namespace gputools { +class Event; +class Stream; +class StreamExecutor; +} // namespace gputools +} // namespace perftools + +namespace tensorflow { + +// An object to keep track of pending Events in the StreamExecutor streams +// and associated Tensors that cannot safely be deleted until the associated +// Events are recorded. +class EventMgr { + public: + explicit EventMgr(perftools::gputools::StreamExecutor* se); + + ~EventMgr(); + + // Takes ownership of *tensors and deletes it as soon as all events + // currently enqueued on *stream have completed. + inline void ThenDeleteTensors(perftools::gputools::Stream* stream, + std::vector* tensors) { + mutex_lock l(mu_); + QueueTensors(stream, tensors); + PollEvents(false); + } + + struct BufRec { + Allocator* alloc; + void* buf; + }; + + // Takes ownership of *bufrec.buf and calls bufrec.alloc->DeallocateRaw() + // on it as soon as all events currently enqueued on *stream have completed. + inline void ThenDeleteBuffer(perftools::gputools::Stream* stream, + BufRec bufrec) { + mutex_lock l(mu_); + QueueBuffer(stream, bufrec); + PollEvents(false); + } + + inline void ThenExecute(perftools::gputools::Stream* stream, + std::function func) { + mutex_lock l(mu_); + QueueFunc(stream, func); + PollEvents(false); + } + + private: + friend class TEST_EventMgrHelper; + mutex mu_; + perftools::gputools::StreamExecutor* exec_; + + struct InUse { + perftools::gputools::Event* event; + std::vector* mem; + BufRec bufrec; + std::function func; + }; + + // Stream-enqueue an unused Event and save with it a collection of + // Tensors and/or a BufRec to be deleted only after the Event + // records. 
+ void QueueInUse(perftools::gputools::Stream* stream, InUse in_use) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + void QueueTensors(perftools::gputools::Stream* stream, + std::vector* tensors) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr}); + } + + void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr}); + } + + void QueueFunc(perftools::gputools::Stream* stream, + std::function func) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + QueueInUse(stream, {nullptr, nullptr, BufRec(), func}); + } + + // This function should be called at roughly the same tempo as + // QueueTensors() to check whether pending events have recorded, + // and then retire them. + void PollEvents(bool is_dedicated_poller) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // An internal polling loop that runs at a low frequency to clear + // straggler Events. + void PollLoop(); + + // A stack of unused events + std::vector free_events_ GUARDED_BY(mu_); + + // A FIFO queue of InUse events and associated tensors. + std::deque used_events_ GUARDED_BY(mu_); + + Notification stop_polling_; + Notification polling_stopped_; + + // The main PollLoop for the event manager runs in this threadpool. + thread::ThreadPool threadpool_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc new file mode 100644 index 0000000000..30ca1ff187 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc @@ -0,0 +1,152 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" + +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/stream_executor.h" +#include + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +class TEST_EventMgrHelper { + public: + explicit TEST_EventMgrHelper(EventMgr* em) : em_(em) {} + + int queue_size() { + mutex_lock l(em_->mu_); + return em_->used_events_.size(); + } + + int free_size() { + mutex_lock l(em_->mu_); + return em_->free_events_.size(); + } + + void QueueTensors(perftools::gputools::Stream* stream, + std::vector* tensors) { + mutex_lock l(em_->mu_); + em_->QueueTensors(stream, tensors); + } + + void PollEvents(bool is_dedicated_poller) { + mutex_lock l(em_->mu_); + em_->PollEvents(is_dedicated_poller); + } + + private: + EventMgr* em_; +}; + +namespace { + +TEST(EventMgr, Empty) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec); + TEST_EventMgrHelper th(&em); + EXPECT_EQ(0, th.queue_size()); + EXPECT_EQ(0, th.free_size()); +} + +// Delaying polling until after several enqueings should grow the +// total number of allocated events. Once we have enough events for +// the max simultaneously pending, we should not allocate any more. 
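// For reference, the production pattern these tests model, as used by
// BaseGPUDevice::Compute in gpu_device.cc above:
//
//   std::vector<Tensor>* tensor_refs = new std::vector<Tensor>;
//   // ... append refs to the op's inputs, temps and outputs ...
//   em_->ThenDeleteTensors(stream, tensor_refs);  // released once queued GPU work completes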
+TEST(EventMgr, DelayedPolling) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec); + TEST_EventMgrHelper th(&em); + EXPECT_EQ(0, th.queue_size()); + std::vector* v = nullptr; + std::unique_ptr stream(new gpu::Stream(stream_exec)); + CHECK(stream.get()); + stream->Init(); + for (int i = 0; i < 5; ++i) { + v = new std::vector; + th.QueueTensors(stream.get(), v); + EXPECT_EQ(i + 1, th.queue_size()); + EXPECT_EQ(0, th.free_size()); + } + th.PollEvents(false); + EXPECT_EQ(0, th.queue_size()); + EXPECT_EQ(5, th.free_size()); + for (int j = 0; j < 2; ++j) { + for (int i = 0; i < 5; ++i) { + v = new std::vector; + th.QueueTensors(stream.get(), v); + EXPECT_EQ(i + 1, th.queue_size()); + EXPECT_EQ(4 - i, th.free_size()); + } + th.PollEvents(false); + EXPECT_EQ(0, th.queue_size()); + EXPECT_EQ(5, th.free_size()); + } +} + +// Immediate polling should require only one event to be allocated. +TEST(EventMgr, ImmediatePolling) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec); + TEST_EventMgrHelper th(&em); + EXPECT_EQ(0, th.queue_size()); + EXPECT_EQ(0, th.free_size()); + std::vector* v = nullptr; + std::unique_ptr stream(new gpu::Stream(stream_exec)); + CHECK(stream.get()); + stream->Init(); + for (int i = 0; i < 5; ++i) { + v = new std::vector; + em.ThenDeleteTensors(stream.get(), v); + EXPECT_EQ(0, th.queue_size()); + EXPECT_EQ(1, th.free_size()); + } +} + +// If we delay polling by more than 1 second, the backup polling loop +// should clear the queue. +TEST(EventMgr, LongDelayedPolling) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec); + TEST_EventMgrHelper th(&em); + EXPECT_EQ(0, th.queue_size()); + EXPECT_EQ(0, th.free_size()); + std::vector* v = nullptr; + std::unique_ptr stream(new gpu::Stream(stream_exec)); + CHECK(stream.get()); + stream->Init(); + for (int i = 0; i < 5; ++i) { + v = new std::vector; + th.QueueTensors(stream.get(), v); + EXPECT_EQ(1 + i, th.queue_size()); + EXPECT_EQ(0, th.free_size()); + } + sleep(1); + EXPECT_EQ(0, th.queue_size()); + EXPECT_EQ(5, th.free_size()); +} + +// Deleting the EventMgr when events are still pending should shut +// down gracefully. 
+TEST(EventMgr, NonEmptyShutdown) { + auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie(); + EventMgr em(stream_exec); + TEST_EventMgrHelper th(&em); + EXPECT_EQ(0, th.queue_size()); + EXPECT_EQ(0, th.free_size()); + std::vector* v = nullptr; + std::unique_ptr stream(new gpu::Stream(stream_exec)); + CHECK(stream.get()); + stream->Init(); + for (int i = 0; i < 5; ++i) { + v = new std::vector; + th.QueueTensors(stream.get(), v); + EXPECT_EQ(1 + i, th.queue_size()); + EXPECT_EQ(0, th.free_size()); + } +} + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.cc b/tensorflow/core/common_runtime/gpu/gpu_init.cc new file mode 100644 index 0000000000..631a47eb91 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_init.cc @@ -0,0 +1,147 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" + +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/stream_executor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +namespace { + +std::unique_ptr, bool>> GetPeerAccessMap( + gpu::Platform* platform, int device_count) { + auto* map = new std::map, bool>; + for (int i = 0; i < device_count; ++i) { + for (int j = 0; j < device_count; ++j) { + gpu::StreamExecutor* from = platform->ExecutorForDevice(i).ValueOrDie(); + gpu::StreamExecutor* to = platform->ExecutorForDevice(j).ValueOrDie(); + (*map)[{i, j}] = from->CanEnablePeerAccessTo(to); + } + } + + return std::unique_ptr, bool>>{map}; +} + +Status EnablePeerAccess(gpu::Platform* platform, int device_count) { + for (int i = 0; i < device_count; ++i) { + for (int j = 0; j < device_count; ++j) { + gpu::StreamExecutor* from = platform->ExecutorForDevice(i).ValueOrDie(); + gpu::StreamExecutor* to = platform->ExecutorForDevice(j).ValueOrDie(); + + if (from->CanEnablePeerAccessTo(to)) { + auto status = from->EnablePeerAccessTo(to); + if (!status.ok()) { + return errors::Internal(status.ToString()); + } + } else { + LOG(INFO) << "cannot enable peer access from device ordinal " << i + << " to device ordinal " << j; + } + } + } + return Status::OK(); +} + +static void InitGPU() { + auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA"); + if (!result.ok()) { + LOG(WARNING) + << "Not initializing the GPU, could not create GPU MachineManager. " + << "Error: " << result.status(); + return; + } + + gpu::Platform* platform = result.ValueOrDie(); + + int dev_count = platform->VisibleDeviceCount(); + + if (dev_count == 0) { + LOG(INFO) << "No GPU devices available on machine."; + return; + } + + for (int i = 0; i < dev_count; ++i) { + auto stream_exec = platform->ExecutorForDevice(i).ValueOrDie(); + int64 free_bytes; + int64 total_bytes; + if (!stream_exec->DeviceMemoryUsage(&free_bytes, &total_bytes)) { + // Logs internally on failure. + free_bytes = 0; + total_bytes = 0; + } + const auto& description = stream_exec->GetDeviceDescription(); + int cc_major; + int cc_minor; + if (!description.cuda_compute_capability(&cc_major, &cc_minor)) { + // Logs internally on failure. 
+ cc_major = 0; + cc_minor = 0; + } + LOG(INFO) << "Found device " << i << " with properties: " + << "\nname: " << description.name() << "\nmajor: " << cc_major + << " minor: " << cc_minor << " memoryClockRate (GHz) " + << description.clock_rate_ghz() << "\npciBusID " + << description.pci_bus_id() << "\nTotal memory: " + << strings::HumanReadableNumBytes(total_bytes) + << "\nFree memory: " + << strings::HumanReadableNumBytes(free_bytes); + } + + // Enable peer access + + auto status = EnablePeerAccess(platform, dev_count); + if (!status.ok()) { + LOG(FATAL) << "could not enable peer access for GPU devices: " << status; + } + + // Print out a matrix showing which devices can DMA to one + // another. + auto access_map = GetPeerAccessMap(platform, dev_count); + string line_buf = "DMA: "; + for (int i = 0; i < dev_count; ++i) { + strings::StrAppend(&line_buf, i, " "); + } + LOG(INFO) << line_buf; + for (int i = 0; i < dev_count; ++i) { + line_buf = strings::StrCat(i, ": "); + for (int j = 0; j < dev_count; ++j) { + if ((*access_map)[{i, j}]) { + line_buf.append("Y "); + } else { + line_buf.append("N "); + } + } + LOG(INFO) << line_buf; + } +} + +static bool InitModule() { + InitGPU(); + return true; +} + +} // namespace + +gpu::Platform* GPUMachineManager() { + // Create the machine manager singleton and initialize the GPUs only + // once. + static bool init = InitModule(); + CHECK(init); // Avoids compiler warning that init is unused. + + auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA"); + if (!result.ok()) { + return nullptr; + } + + return result.ValueOrDie(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.h b/tensorflow/core/common_runtime/gpu/gpu_init.h new file mode 100644 index 0000000000..d126a8b1ca --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_init.h @@ -0,0 +1,19 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_ + +namespace perftools { +namespace gputools { +class Platform; +} // namespace gputools +} // namespace perftools + +namespace tensorflow { + +// Returns the GPU machine manager singleton, creating it and +// initializing the GPUs on the machine if needed the first time it is +// called. 
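+//
+// Returns nullptr if no CUDA platform could be initialized.
+// Illustrative usage (sketch only):
+//
+//   perftools::gputools::Platform* platform = GPUMachineManager();
+//   if (platform != nullptr && platform->VisibleDeviceCount() > 0) {
+//     auto* exec = platform->ExecutorForDevice(0).ValueOrDie();
+//     // ... create streams/allocators against 'exec' ...
+//   }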
+perftools::gputools::Platform* GPUMachineManager(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc new file mode 100644 index 0000000000..08ff55e221 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc @@ -0,0 +1,371 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h" + +//#include "base/commandlineflags.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h" +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +#if defined(PLATFORM_GOOGLE) +DEFINE_bool(brain_gpu_region_allocator_heap_check_on_destruction, true, + "If true, the CUDA gpu manager checks that all allocated " + "memory through the GPU memory pool implementation has been " + "freed."); + +DEFINE_int64(brain_gpu_region_allocator_region_size, 0, + "If > 0, sets the default chunk-size allocatable from GPU memory. " + "Else defaults to entire GPU memory."); + +#else +bool FLAGS_brain_gpu_region_allocator_heap_check_on_destruction = true; +tensorflow::int64 FLAGS_brain_gpu_region_allocator_region_size = 0; +#endif + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +GPURegionAllocator::GPURegionAllocator(int device_id, size_t total_bytes) + : device_id_(device_id), total_bytes_(total_bytes) { + // Get a pointer to the stream_executor for this device + stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); + + // Set the region size based on explicit user request, or based on + // total GPU capacity. 
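+  // For example, leaving the flag at its default of 0 on a device
+  // whose budget (total_bytes_) is 12 GiB makes region_size_ equal to
+  // total_bytes_, so the backing memory is reserved as one large
+  // region (illustrative numbers).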
+ if (FLAGS_brain_gpu_region_allocator_region_size > 0) { + region_size_ = FLAGS_brain_gpu_region_allocator_region_size; + } else { + region_size_ = static_cast(total_bytes_); + } + + LOG(INFO) << "Setting region size to " << region_size_; +} + +GPURegionAllocator::~GPURegionAllocator() { + if (FLAGS_brain_gpu_region_allocator_heap_check_on_destruction) { + CheckForMemoryLeaks(); + } + + gtl::STLDeleteValues(&chunk_map_); + + for (auto r : regions_) { + gpu::DeviceMemoryBase gpu_ptr{r->ptr}; + stream_exec_->Deallocate(&gpu_ptr); + delete r; + } +} + +void* GPURegionAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { + static const int64 kMaxMillisToWait = 10000; // 10 seconds + return retry_helper_.AllocateRaw( + [this](size_t a, size_t nb, bool v) { + return AllocateRawInternal(a, nb, v); + }, + kMaxMillisToWait, alignment, num_bytes); +} + +void* GPURegionAllocator::AllocateRawInternal(size_t alignment, + size_t num_bytes, + bool dump_log_on_failure) { + if (num_bytes == 0) { + LOG(ERROR) << "tried to allocate 0 bytes"; + return nullptr; + } + size_t chunk_size = ChunkSize(num_bytes); + + VLOG(2) << "chunk_size " << chunk_size << " from num_bytes " + << strings::HumanReadableNumBytes(num_bytes); + mutex_lock l(lock_); + Pool* pool = &pools_[chunk_size]; + if (pool->num_free == 0) { + if (!ExpandPool(pool, chunk_size, num_bytes, dump_log_on_failure)) { + if (dump_log_on_failure) { + LOG(WARNING) << "Out of GPU memory, see memory state dump above"; + } + return nullptr; + } + } + CHECK_LT(0, pool->num_free); + CHECK(pool->first); + CHECK(pool->last); + Chunk* c = pool->first; + CHECK(c); + CHECK(!c->in_use); + + c->in_use = true; + // Move c to the back of the queue. + if (c->next != nullptr) { + pool->first = c->next; + pool->first->prev = nullptr; + c->next = nullptr; + } + + if (pool->last != c) { + pool->last->next = c; + c->prev = pool->last; + pool->last = c; + } + pool->num_free--; + pool->cumulative_malloced++; + + void* rv = c->ptr; + c->bytes_allocated = num_bytes; + + VLOG(2) << "new ptr " << rv; + return rv; +} + +void GPURegionAllocator::DeallocateRaw(void* ptr) { + retry_helper_.DeallocateRaw([this](void* p) { DeallocateRawInternal(p); }, + ptr); +} + +void GPURegionAllocator::DeallocateRawInternal(void* ptr) { + VLOG(2) << "DeallocateRaw: " << ptr; + if (ptr == nullptr) { + LOG(ERROR) << "tried to deallocate nullptr"; + return; + } + + mutex_lock l(lock_); + ChunkMap::const_iterator iter = chunk_map_.find(ptr); + CHECK(iter != chunk_map_.end()); + + Chunk* c = iter->second; + VLOG(2) << "chunk of size " << c->size << " at " << c; + + Pool* pool = &(pools_[c->size]); + // Move chunk to head of queue, and mark free. + DCHECK(c->in_use); + c->in_use = false; + if (c->prev) c->prev->next = c->next; + if (c->next) c->next->prev = c->prev; + if (pool->first == c) pool->first = c->next; + if (pool->last == c) pool->last = c->prev; + c->next = pool->first; + c->prev = nullptr; + if (c->next) c->next->prev = c; + pool->first = c; + if (pool->last == nullptr) pool->last = c; + pool->num_free++; + pool->cumulative_freed++; +} + +bool GPURegionAllocator::ExpandPool(Pool* pool, size_t chunk_size, + size_t requested_size, + bool dump_log_on_failure) { + VLOG(1) << "ExpandPool of " << chunk_size << " from " << pool->num_chunks + << " current members"; + DCHECK_NE(0, chunk_size); + // If chunk_size is < 4096, double the pool size. Otherwise + // just increase by one. 
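+  // Worked example: with an empty pool and chunk_size == 256 the code
+  // below starts from num_chunks = 4096 / 256 = 16; with
+  // chunk_size == 8192 it starts from num_chunks = 1.  When an
+  // existing pool is doubled, the 1 MiB aggregate cap below limits
+  // how many chunks are added in one expansion.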
+ int num_chunks = pool->num_chunks; + if (num_chunks == 0) { + if (chunk_size > 4096) { + num_chunks = 1; + } else { + num_chunks = 4096 / chunk_size; + } + } + // For larger chunks, limit the amount of expansion. + size_t aggregate_size = num_chunks * chunk_size; + if (aggregate_size > (1 << 20)) { + num_chunks = static_cast( + std::max(static_cast(1), (1 << 20) / chunk_size)); + } + while (num_chunks > 0) { + Region* r = (regions_.empty() ? nullptr : regions_.back()); + if (r == nullptr || + (((r->ptr + r->size) - r->next) < static_cast(chunk_size))) { + // Current region is not large enough to accommodate another chunk. + while (r == nullptr || (((r->ptr + r->size) - r->next) < + static_cast(chunk_size))) { + // Get another region. + size_t this_region_size = std::max(region_size_, chunk_size); + + // Check if we would exceed our limit. + if (allocated_memory_ + this_region_size > total_bytes_) { + if (dump_log_on_failure) DumpMemoryLog(); + return false; + } + + // Perform the allocation, still checking that the allocator + // has not run out of memory. + gpu::DeviceMemory gpu_mem = + stream_exec_->AllocateArray(this_region_size); + if (gpu_mem == nullptr) { + if (dump_log_on_failure) DumpMemoryLog(); + return false; + } + + // We never release memory once expanded. + allocated_memory_ += this_region_size; + + Region* nr = new Region; + nr->ptr = static_cast(gpu_mem.opaque()); + + if (VLOG_IS_ON(2)) { + int64 free_bytes; + int64 total_bytes; + if (stream_exec_->DeviceMemoryUsage(&free_bytes, &total_bytes)) { + VLOG(2) << "free " << free_bytes << " total " << total_bytes; + } else { + // Note: stream_exec call also logs internally on failure. + VLOG(2) << "could not retrieve memory usage"; + } + } + VLOG(1) << "new Region of size " << this_region_size << " at " + << static_cast(nr->ptr) << " on device " << device_id_; + r = nr; + r->size = this_region_size; + r->next = r->ptr; + regions_.push_back(r); + + for (auto visitor : region_visitors_) { + visitor(r->ptr, r->size); + } + } + } else { + // Allocate a new chunk and push on front of Pool. + Chunk* c = new Chunk; + c->ptr = r->next; + chunk_map_[c->ptr] = c; + c->size = chunk_size; + r->next += chunk_size; + c->next = pool->first; + if (c->next != nullptr) c->next->prev = c; + pool->first = c; + if (pool->last == nullptr) pool->last = c; + pool->num_chunks++; + pool->num_free++; + --num_chunks; + } + } + + return true; +} + +void GPURegionAllocator::CheckForMemoryLeaks() { + std::vector errors; + mutex_lock l(lock_); // could use reader lock + for (auto pool_map : pools_) { + const Pool& p = pool_map.second; + Chunk* curr_chunk = p.first; + while (curr_chunk != nullptr) { + if (curr_chunk->in_use) { + errors.push_back( + strings::StrCat("Unfreed chunk of size ", curr_chunk->size)); + } + curr_chunk = curr_chunk->next; + } + } + if (!errors.empty()) { + LOG(FATAL) << "GPU Memory leaks:\n" << str_util::Join(errors, "\n"); + } +} + +// Since there's no merging of chunks once allocated, we want to +// maximize their reusablity (which argues for fewer, larger sizes), +// while minimizing waste (which argues for tight-fitting sizes). +// +// The smallest unit of allocation is 256 bytes. +// NOTE(tucker): akrizhevsky says that nvidia's memory manager always +// aligns to 256 bytes, and doing so results in significant speedup. +// +// Up to 2^16 bytes we only allocate in powers of 2. 
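+// (e.g. a 300-byte request is rounded up to a 512-byte chunk and a
+// 3000-byte request to 4096 bytes).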
+// +// Above that, we pick a max-waste which is the largest power +// of 2 <= 1/16 of the requested size, then round up to the nearest +// multiple of max_waste. +// +// static +size_t GPURegionAllocator::ChunkSize(size_t bytes) { + if (bytes <= 256) { + return 256; + } else if (bytes <= (1 << 16)) { + return 1uLL << Log2Ceiling64(bytes); + } else { + // 1/16th of requested size + size_t max_waste = 1uLL << (Log2Ceiling64(bytes) - 4); + return (bytes + max_waste) & (~(max_waste - 1)); + } +} + +void GPURegionAllocator::AddAllocVisitor(Visitor visitor) { + VLOG(1) << "AddVisitor"; + mutex_lock l(lock_); + region_visitors_.push_back(visitor); + for (auto region : regions_) { + visitor(region->ptr, region->size); + } +} + +void GPURegionAllocator::DumpMemoryLog() { + size_t region_bytes = 0; + for (auto r : regions_) { + region_bytes += r->size; + } + size_t chunk_bytes = 0; + std::vector chunk_sizes; + for (auto i : pools_) { + chunk_sizes.push_back(i.first); + } + std::sort(chunk_sizes.begin(), chunk_sizes.end()); + for (auto i : chunk_sizes) { + int32 chunks_in_use = 0; + const Pool& p = pools_[i]; + chunk_bytes += i * p.num_chunks; + + if (p.num_chunks > 0) { + // Iterate backwards (allocated chunks are last). + Chunk* curr_chunk = p.last; + while (curr_chunk != nullptr) { + if (curr_chunk->in_use) { + ++chunks_in_use; + } + curr_chunk = curr_chunk->prev; + if (curr_chunk == p.first) { + break; + } + } + } + + LOG(INFO) << "Chunk size: " << i << " (" + << strings::HumanReadableNumBytes(i) << ") Pool: " << p.ToString() + << "\nNumber of chunks: " << p.num_chunks + << ", in_use chunks: " << chunks_in_use; + } + + LOG(INFO) << "Aggregate Region Memory: " << region_bytes << " (" + << strings::HumanReadableNumBytes(region_bytes) << ")"; + LOG(INFO) << "Aggregate Chunk Memory: " << chunk_bytes << " (" + << strings::HumanReadableNumBytes(chunk_bytes) << ")"; +} + +bool GPURegionAllocator::TracksAllocationSizes() { return true; } + +size_t GPURegionAllocator::RequestedSize(void* ptr) { + mutex_lock l(lock_); + auto it = chunk_map_.find(ptr); + CHECK(it != chunk_map_.end()) + << "Asked for requested size of pointer we never allocated: " << ptr; + auto c = it->second; + return c->bytes_allocated; +} + +size_t GPURegionAllocator::AllocatedSize(void* ptr) { + mutex_lock l(lock_); + auto it = chunk_map_.find(ptr); + CHECK(it != chunk_map_.end()) + << "Asked for allocated size of pointer we never allocated: " << ptr; + auto c = it->second; + return c->size; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h new file mode 100644 index 0000000000..1a250b6ede --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h @@ -0,0 +1,146 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_ + +#include +#include +#include +#include + +#include "tensorflow/stream_executor/stream_executor.h" +#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h" +#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +class GPURegionAllocator : public VisitableAllocator { + public: + // 'device_id' must be a valid device on the machine. + // + // total_bytes is how many bytes this allocator should allocate up + // to. 
This may be less than the total available. + explicit GPURegionAllocator(int device_id, size_t total_bytes); + ~GPURegionAllocator() override; + + string Name() override { return "gpu_region"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + void AddAllocVisitor(Visitor visitor) override; + // Does nothing, because regions are never freed. + void AddFreeVisitor(Visitor visitor) override {} + + bool TracksAllocationSizes() override; + size_t RequestedSize(void* ptr) override; + size_t AllocatedSize(void* ptr) override; + + private: + // A Chunk is the header on a single piece of memory given back + // in response to an AllocateRaw() call. + struct Chunk { + char* ptr; // pointer to granted GPU buffer. + size_t size; // Full size of GPU buffer. + size_t bytes_allocated; // Bytes asked for by client. + bool in_use; + Chunk* prev; // Used for chaining in pool. + Chunk* next; + Chunk() + : ptr(nullptr), + size(0), + bytes_allocated(0), + in_use(false), + prev(nullptr), + next(nullptr) {} + }; + + // A Pool is a collection of same-sized Chunks. + struct Pool { + int num_chunks; // total chunks in this pool + int num_free; // total free chunks in this pool + int64 cumulative_malloced; // number of chunks malloced so far + int64 cumulative_freed; // number of chunks freed so far + + // double-linked ring of chunks; all free chunks precede all + // granted chunks + Chunk* first; + Chunk* last; + Pool() + : num_chunks(0), + num_free(0), + cumulative_malloced(0), + cumulative_freed(0), + first(nullptr), + last(nullptr) {} + + string ToString() const { + return strings::StrCat("chunks: ", num_chunks, " free: ", num_free, + " cumulative malloc: ", cumulative_malloced, + " cumulative freed: ", cumulative_freed); + } + }; + + // A Region is a single area of GPU memory that has been + // reserved by this class and carved up into Chunks. + struct Region { + char* ptr; // base GPU ptr + char* next; // frontier of unused part of region + size_t size; + Region() : ptr(nullptr), size(0) {} + }; + + // Calculate size of chunk for an allocation of this size. + // Min chunk size is 16, for alignment. + // For larger sizes, we round up somewhat so there are fewer + // size-specific pools. + static size_t ChunkSize(size_t bytes); + + void* AllocateRawInternal(size_t alignment, size_t num_bytes, + bool dump_log_on_failure); + void DeallocateRawInternal(void* ptr); + + bool ExpandPool(Pool* p, size_t chunk_size, size_t requested_size, + bool dump_log_on_failure) EXCLUSIVE_LOCKS_REQUIRED(lock_); + + // Inspects region maps and crashes with debug information if there + // are any memory leaks as detected by the region allocator. + void CheckForMemoryLeaks() LOCKS_EXCLUDED(lock_); + + void DumpMemoryLog() EXCLUSIVE_LOCKS_REQUIRED(lock_); + + perftools::gputools::StreamExecutor* stream_exec_; // Not owned. + + typedef std::unordered_map PoolMap; + typedef std::unordered_map ChunkMap; + + GPUAllocatorRetry retry_helper_; + mutable mutex lock_; + PoolMap pools_ GUARDED_BY(lock_); + + // Owns regions. + std::vector regions_ GUARDED_BY(lock_); + + // Maps from GPU ptr to Chunk owning it. + // + // Owns chunks. + ChunkMap chunk_map_ GUARDED_BY(lock_); + + // Called once on each region, ASAP. + std::vector region_visitors_ GUARDED_BY(lock_); + + const int device_id_; + + // Total amount of memory (in bytes) available to this Allocator + const size_t total_bytes_; + + // Total amount of memory allocated to regions. 
+ size_t allocated_memory_ = 0; + + size_t region_size_ = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(GPURegionAllocator); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc new file mode 100644 index 0000000000..07b0dd57f6 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc @@ -0,0 +1,71 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h" + +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/stream_executor/stream_executor.h" +#include + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { +namespace { + +TEST(GPURegionAllocatorTest, Simple) { + GPURegionAllocator a(0, 1 << 26); + std::vector ptrs; + for (int s = 1; s < 1024; s++) { + void* raw = a.AllocateRaw(1, s); + ptrs.push_back(raw); + } + std::sort(ptrs.begin(), ptrs.end()); + for (int i = 0; i < ptrs.size(); i++) { + if (i > 0) { + CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups + } + a.DeallocateRaw(ptrs[i]); + } + float* t1 = a.Allocate(1024); + double* t2 = a.Allocate(1048576); + a.Deallocate(t1); + a.Deallocate(t2); +} + +TEST(GPURegionAllocatorTest, CheckMemLeak) { + EXPECT_DEATH( + { + GPURegionAllocator a(0, 1 << 26); + float* t1 = a.Allocate(1024); + if (t1) { + LOG(INFO) << "Not deallocating"; + } + }, + ""); +} + +TEST(GPURegionAllocatorTest, TracksSizes) { + GPURegionAllocator a(0, 1 << 26); + EXPECT_EQ(true, a.TracksAllocationSizes()); +} + +TEST(GPURegionAllocatorTest, AllocatedVsRequested) { + GPURegionAllocator a(0, 1 << 26); + float* t1 = a.Allocate(1); + EXPECT_EQ(sizeof(float), a.RequestedSize(t1)); + + // Minimum allocation size if 256 + EXPECT_EQ(256, a.AllocatedSize(t1)); + + a.Deallocate(t1); +} + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc new file mode 100644 index 0000000000..ca86c7fa06 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc @@ -0,0 +1,97 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h" + +#include +#include +#include +#include + +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace gpu_stream_util { + +Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts, + std::unordered_map* node_to_stream_id) { + VLOG(1) << "AssignStreams"; + Status status; + + // Sanity check arguments. + if (graph == nullptr) + status.Update(errors::InvalidArgument("Bad graph argument supplied.")); + if (node_to_stream_id == nullptr) { + status.Update( + errors::InvalidArgument("Bad node_to_stream_id argument supplied.")); + } + if ((opts.max_streams < 1) || (opts.send_stream >= opts.max_streams) || + (opts.recv_stream >= opts.max_streams) || + (opts.const_stream >= opts.max_streams) || + (opts.compute_stream >= opts.max_streams)) { + status.Update(errors::InvalidArgument("Bad graph argument supplied.")); + } + TF_RETURN_IF_ERROR(status); + + // Topologically sort the nodes. 
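+  // GetReversePostOrder() yields the nodes in reverse DFS post-order,
+  // which for a DAG places every node before the nodes that consume
+  // its outputs.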
+ std::vector order; + GetReversePostOrder(*graph, &order); + if (VLOG_IS_ON(2)) { + for (Node* n : order) { + const int node_id = n->id(); + VLOG(2) << "Node " << node_id << " " << n->type_string() << " " + << n->name() << " " << n->in_edges().size() << " inputs"; + for (const Edge* e : n->in_edges()) { + VLOG(2) << " Edge from " << e->src()->id() << " " << e->src()->name() + << " fanout " << e->src()->out_edges().size(); + } + } + } + // We perform stream assigmnent assuming a large number of + // stream IDs and then map these down to the required number of streams + // using simple round-robin. + // Stream Assignment strategy: + // 1. Nodes with zero inputs are always be executed on a + // fresh stream. + // 2. Try to execute a node on the same stream as one of its + // inputs to avoid inter-stream dependencies. + // 3. If any input comes from a node with a large fanout then + // perhaps an indication that it is shared between parallel + // streams of work. We choose a new stream here so that all consumers + // of the tensor are likely to run in parallel. + int highest_stream_id = -1; + for (Node* n : order) { + VLOG(3) << "Inspecting node " << n->DebugString(); + const int node_id = n->id(); + const string& op = n->type_string(); + + // Determine a suitable stream to use. + int stream_id = highest_stream_id + 1; + for (const Edge* e : n->in_edges()) { + const int fanout = e->src()->out_edges().size(); + if (fanout == 1) { + stream_id = (*node_to_stream_id)[e->src()->id()]; + break; + } + } + // Override stream for specific op types. + if (op == "_Send") { + if (opts.send_stream >= 0) stream_id = opts.send_stream; + } else if (op == "_Recv") { + if (opts.recv_stream >= 0) stream_id = opts.recv_stream; + } else if (op == "Const") { + if (opts.const_stream >= 0) stream_id = opts.const_stream; + } else { + if (opts.compute_stream >= 0) stream_id = opts.compute_stream; + } + + (*node_to_stream_id)[node_id] = stream_id % opts.max_streams; + highest_stream_id = std::max(stream_id, highest_stream_id); + } + VLOG(1) << "Identified " << highest_stream_id << " candidate streams for " + << order.size() << " nodes."; + + return Status::OK(); +} + +} // namespace gpu_stream_util +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.h b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h new file mode 100644 index 0000000000..e1c623382c --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h @@ -0,0 +1,30 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_ + +#include + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace gpu_stream_util { + +struct AssignStreamsOpts { + int32 max_streams = 1; + // The following options specify a stream to use for specific op + // types. The value -1 allows ops to be assigned to any stream. + int32 send_stream = -1; + int32 recv_stream = -1; + int32 const_stream = -1; + int32 compute_stream = -1; +}; + +// Given the input graph, assigns every node in the graph with a +// stream_id that should be used. 
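+//
+// Illustrative usage (mirrors the unit tests):
+//
+//   gpu_stream_util::AssignStreamsOpts opts;
+//   opts.max_streams = 3;
+//   std::unordered_map<int, int> node_to_stream_id;
+//   Status s =
+//       gpu_stream_util::AssignStreams(&graph, opts, &node_to_stream_id);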
+Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts, + std::unordered_map* node_to_stream_id); + +} // namespace gpu_stream_util +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc new file mode 100644 index 0000000000..5c426caaef --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc @@ -0,0 +1,137 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h" + +#include +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +class GpuStreamUtilTest : public OpsTestBase { + protected: + void SetUp() override { RequireDefaultOps(); } +}; + +TEST_F(GpuStreamUtilTest, BogusOpts) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Graph g(OpRegistry::Global()); + ASSERT_OK(b.ToGraph(&g)); + std::unordered_map node_to_stream_id; + gpu_stream_util::AssignStreamsOpts opts; + Status status; + status = gpu_stream_util::AssignStreams(nullptr, opts, &node_to_stream_id); + EXPECT_FALSE(status.ok()); + status = gpu_stream_util::AssignStreams(&g, opts, nullptr); + EXPECT_FALSE(status.ok()); + opts.max_streams = 0; + status = gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id); + EXPECT_FALSE(status.ok()); + opts.max_streams = 1; + opts.compute_stream = 5; + status = gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id); + EXPECT_FALSE(status.ok()); +} + +TEST_F(GpuStreamUtilTest, EmptyGraph) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Graph g(OpRegistry::Global()); + ASSERT_OK(b.ToGraph(&g)); + std::unordered_map node_to_stream_id; + gpu_stream_util::AssignStreamsOpts opts; + ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id)); + EXPECT_EQ(2, node_to_stream_id.size()); // _SOURCE and _SINK +} + +TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()), + ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts()); + Graph g(OpRegistry::Global()); + ASSERT_OK(b.ToGraph(&g)); + + std::unordered_map node_to_stream_id; + gpu_stream_util::AssignStreamsOpts opts; + ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id)); + + // There should be 5 nodes assigned. + EXPECT_EQ(5, node_to_stream_id.size()); + + // All of them should have stream 0. + for (const auto& it : node_to_stream_id) { + EXPECT_EQ(0, it.second); + } +} + +TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()), + ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts()); + Graph g(OpRegistry::Global()); + ASSERT_OK(b.ToGraph(&g)); + + std::unordered_map node_to_stream_id; + gpu_stream_util::AssignStreamsOpts opts; + opts.max_streams = 3; + ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id)); + + // There should be 5 nodes assigned. 
+ EXPECT_EQ(5, node_to_stream_id.size()); + + // All of them should have a stream in the range [0..max_streams). + for (const auto& it : node_to_stream_id) { + EXPECT_GE(it.second, 0); + EXPECT_LT(it.second, opts.max_streams); + } +} + +TEST_F(GpuStreamUtilTest, StreamOverrides) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::_Recv(DT_FLOAT, "input", "/cpu:0", 0, "/gpu:0", + b.opts().WithName("input")); + auto n = ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()), + ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts()); + ops::_Send(n, "output", "/gpu:0", 0, "/cpu:0", b.opts().WithName("output")); + Graph g(OpRegistry::Global()); + ASSERT_OK(b.ToGraph(&g)); + + // Perform stream assignment using a large number of streams, but with + // op types constrained to specific streams. + std::unordered_map node_to_stream_id; + gpu_stream_util::AssignStreamsOpts opts; + opts.max_streams = 100; + opts.const_stream = 90; + opts.send_stream = 91; + opts.recv_stream = 92; + opts.compute_stream = 93; + ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id)); + + // There should be 7 nodes assigned. + EXPECT_EQ(7, node_to_stream_id.size()); // including _SOURCE and _SINK + + // Nodes should be assigned to streams by op type. + for (const auto& it : node_to_stream_id) { + Node* n = g.FindNodeId(it.first); + const string op = n->type_string(); + const int stream = it.second; + if (op == "Const") { + EXPECT_EQ(stream, 90); + } else if (op == "_Send") { + EXPECT_EQ(stream, 91); + } else if (op == "_Recv") { + EXPECT_EQ(stream, 92); + } else { // Compute. + EXPECT_EQ(stream, 93); + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc new file mode 100644 index 0000000000..a6a3ce01fc --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -0,0 +1,345 @@ +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" + +//#include "base/commandlineflags.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tensor_coding.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/common_runtime/gpu/dma_helper.h" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/util/util.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/stream_executor.h" + +#include "tensorflow/core/platform/stream_executor_util.h" + +#if defined(PLATFORM_GOOGLE) +DEFINE_int64(brain_gpu_util_debug_string_maxlen, 128, + "When dumping gpu memory, prints up to this many bytes."); + +DECLARE_bool(record_mem_types); +#else +tensorflow::int64 FLAGS_brain_gpu_util_debug_string_maxlen = 128; +bool FLAGS_EXPERIMENTAL_brain_gpu_multi_stream = false; +extern bool FLAGS_record_mem_types; +#endif + +using perftools::gputools::DeviceMemoryBase; +using perftools::gputools::DeviceMemory; +using 
perftools::gputools::Stream; + +namespace tensorflow { + +namespace gpu = ::perftools::gputools; + +/*static*/ +void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev, + const DeviceContext* device_context, + TensorProto* proto, bool is_dead, + StatusCallback done) { + VLOG(1) << "SetProtoFromGPU device_context " << device_context; + // Tensor values need to be copied from GPU to CPU ram so that + // we can build the protobuf response for a RecvTensor RPC. + // "device context" identifies the stream where the _Send op executed. + CHECK(device_context); + gpu::Stream* stream = + static_cast(device_context)->stream(); + + if (!DMAHelper::CanUseDMA(&tensor)) { + done(errors::Internal(strings::StrCat( + "GPU copy from non-DMA ", DataTypeString(tensor.dtype()), "tensor"))); + return; + } + proto->set_dtype(tensor.dtype()); + tensor.shape().AsProto(proto->mutable_tensor_shape()); + // Prepare a Cord with the right data buf size, and DMA the + // data over from the GPU buffer. Note that 0-size tensors + // do not have a backing buffer. + const size_t num_bytes = is_dead ? 0 : tensor.TotalBytes(); + if (num_bytes > 0) { + port::Tracing::ScopedAnnotation annotation("SetProtoFromGPU"); + Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + char* mb = alloc->Allocate(num_bytes); + const char* src_ptr = + reinterpret_cast(DMAHelper::base(&tensor)); + DeviceMemoryBase gpu_src_ptr(const_cast(src_ptr), num_bytes); + stream->ThenMemcpy(mb, gpu_src_ptr, num_bytes); + // Use of tensor may outlive stack scope, so keep a ref. + Tensor* tensor_ref = new Tensor(tensor); + dev->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, [stream, done, proto, mb, num_bytes, alloc, tensor_ref]() { + if (!stream->ok()) { + done(errors::Internal("SetProtoFromGPU: GPU Memcpy failed")); + // TODO(pbar) We currently have no way to recover the + // worker from a GPU stream in the error state. Until + // there is a way to reset the CUDA driver, it is + // preferable to crash the process and restart. Tracked + // under b/23717097 + LOG(FATAL) << "SetProtoFromGPU: GPU Memcpy failed"; + return; + } + delete tensor_ref; + port::CopyFromArray(proto->mutable_tensor_content(), mb, num_bytes); + alloc->Deallocate(mb); + done(Status::OK()); + }); + } else { + done(Status::OK()); + } +} + +typedef ProcessState::MemDesc PMD; + +/*static*/ +void GPUUtil::CopyViaDMA(const string& edge_name, + DeviceContext* send_dev_context, + DeviceContext* recv_dev_context, Device* src, + Device* dst, AllocatorAttributes src_alloc_attr, + AllocatorAttributes dst_alloc_attr, + const Tensor* input, Tensor* output, + StatusCallback done) { + port::Tracing::ScopedAnnotation annotation(edge_name); + VLOG(1) << "CopyViaDMA " << edge_name; + size_t total_bytes = input->TotalBytes(); + // Note that 0-size tensors have no backing buffer. 
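+  // The copy below dispatches on where the two buffers live:
+  // device-to-device copies are enqueued on the send stream,
+  // device-to-host and host-to-device copies go through the
+  // DeviceContext copy helpers, and host-to-host falls back to
+  // memcpy().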
+ if (total_bytes > 0) { + const void* src_ptr = DMAHelper::base(input); + void* dst_ptr = DMAHelper::base(output); + VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr; + if (FLAGS_record_mem_types) { + ProcessState::MemDesc smd = ProcessState::singleton()->PtrType(src_ptr); + ProcessState::MemDesc dmd = ProcessState::singleton()->PtrType(dst_ptr); + VLOG(0) << "Src " << smd.DebugString() << " Dst " << dmd.DebugString(); + if (smd.loc == PMD::CPU && dmd.loc == PMD::GPU && (!smd.gpu_registered)) { + LOG(WARNING) << "CPU -> GPU no reg for " << edge_name; + } + if (dmd.loc == PMD::CPU && smd.loc == PMD::GPU && (!dmd.gpu_registered)) { + LOG(WARNING) << "GPU -> CPU no reg for " << edge_name; + } + } + + auto src_device_type = src->attributes().device_type(); + auto dst_device_type = dst->attributes().device_type(); + + bool non_cpu_src = (!src_alloc_attr.on_host() && + src_device_type != DeviceType(DEVICE_CPU).type()); + bool non_cpu_dst = (!dst_alloc_attr.on_host() && + dst_device_type != DeviceType(DEVICE_CPU).type()); + if (non_cpu_src) { + gpu::Stream* stream = send_dev_context->stream(); + if (stream == nullptr) { + done(errors::Internal("Failed to find device stream")); + return; + } + auto* src_dev_info = src->tensorflow_gpu_device_info(); + CHECK(src_dev_info); + + if (non_cpu_dst) { + // Device to device copy + DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); + stream->ThenMemcpy( + &gpu_dst_ptr, + DeviceMemoryBase{const_cast(src_ptr), total_bytes}, + total_bytes); + if (dst_device_type == DeviceType(DEVICE_GPU).type()) { + // Use of input may outlive stack scope, so keep a ref. + Tensor* input_ref = new Tensor(*input); + src_dev_info->event_mgr->ThenExecute( + stream, [done, stream, input_ref]() { + delete input_ref; + if (!stream->ok()) { + done(errors::Internal("GPU->GPU Memcpy failed")); + } else { + done(Status::OK()); + } + }); + } + send_dev_context->MaintainLifetimeOnStream(input, stream); + } else { + // Device to host copy. + return send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, + output, done); + } + } else if (non_cpu_dst) { + // Host to Device copy. + // Note that this is already an async copy. + recv_dev_context->CopyCPUTensorToDevice(input, dst, output, done); + } else { + memcpy(dst_ptr, src_ptr, total_bytes); + done(Status::OK()); + } + } else { + // buffer is empty + done(Status::OK()); + } +} + +void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device, + const DeviceContext* device_context, + const Tensor* gpu_tensor, Tensor* cpu_tensor, + StatusCallback done) { + VLOG(1) << "CopyGPUTensorToCPU"; + size_t total_bytes = gpu_tensor->TotalBytes(); + // Note that 0-size tensors have no backing buffer. 
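+  // Unlike the asynchronous copies above, this path blocks the host:
+  // the memcpy is enqueued on the stream and BlockHostUntilDone() is
+  // called before 'done' is invoked.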
+ if (total_bytes > 0) { + const void* src_ptr = DMAHelper::base(gpu_tensor); + void* dst_ptr = DMAHelper::base(cpu_tensor); + CHECK(dst_ptr); + auto* stream = gpu_device->tensorflow_gpu_device_info()->stream; + if (device_context) { + stream = static_cast(device_context)->stream(); + } + stream->ThenMemcpy( + dst_ptr, DeviceMemoryBase{const_cast(src_ptr), total_bytes}, + total_bytes); + stream->BlockHostUntilDone(); + if (!stream->ok()) { + done(errors::Internal("CopyGPUTensorToCPU: GPU->CPU Memcpy failed")); + return; + } + } + + done(Status::OK()); +} + +/* static */ +void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor, + const DeviceContext* device_context, + Device* gpu_device, Tensor* gpu_tensor, + StatusCallback done) { + VLOG(1) << "CopyCPUTensorToGPU"; + CHECK(DeviceType(gpu_device->attributes().device_type()) == + DeviceType(DEVICE_GPU)); + + auto* dev_info = gpu_device->tensorflow_gpu_device_info(); + if (!dev_info) { + done(errors::Internal("Failed to find dest device GPUDeviceInfo")); + return; + } + if (cpu_tensor->TotalBytes() != gpu_tensor->TotalBytes()) { + done(errors::Internal( + strings::StrCat("Can't copy ", cpu_tensor->TotalBytes(), + " bytes of a tensor into another with ", + gpu_tensor->TotalBytes(), " bytes buffer."))); + return; + } + const int64 total_bytes = cpu_tensor->TotalBytes(); + // Note that 0-size tensors have no backing buffer. + if (total_bytes > 0) { + const void* src_ptr = DMAHelper::base(cpu_tensor); + void* dst_ptr = DMAHelper::base(gpu_tensor); + DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); + + CHECK(device_context); + auto* stream = + static_cast(device_context)->stream(); + stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes); + auto* dev_info = gpu_device->tensorflow_gpu_device_info(); + // Use of cpu_tensor may outlive stack scope, so keep a ref. + Tensor* input_ref = new Tensor(*cpu_tensor); + dev_info->event_mgr->ThenExecute(stream, [stream, done, input_ref]() { + delete input_ref; + if (!stream->ok()) { + done(errors::Internal("CopyCPUTensorToGPU: GPU Memcpy failed")); + } else { + done(Status::OK()); + } + }); + } else { + // empty tensor case + done(Status::OK()); + } +} + +Status GPUUtil::Sync(Device* gpu_device) { + VLOG(1) << "GPUUtil::Sync"; + auto* dev_info = gpu_device->tensorflow_gpu_device_info(); + if (!dev_info) { + return errors::Internal("Failed to find dest device GPUDeviceInfo"); + } + dev_info->stream->BlockHostUntilDone(); + if (!dev_info->stream->ok()) { + LOG(FATAL) << "GPU sync failed"; + } + return Status::OK(); +} + +Status GPUUtil::SyncAll(Device* gpu_device) { + VLOG(1) << "GPUUtil::SyncAll"; + auto* dev_info = gpu_device->tensorflow_gpu_device_info(); + if (!dev_info) { + return errors::Internal("Failed to find dest device GPUDeviceInfo"); + } + if (!dev_info->stream->parent()->SynchronizeAllActivity() || + !dev_info->stream->ok()) { + LOG(FATAL) << "GPU sync failed"; + } + return Status::OK(); +} + +string GPUUtil::MemoryDebugString(const Device* device, Tensor* tensor) { + string ret; + CHECK(tensor); + const int64 num_bytes = std::min( + FLAGS_brain_gpu_util_debug_string_maxlen, tensor->TotalBytes()); + void* ptr = (num_bytes > 0) ? 
DMAHelper::base(tensor) : nullptr; + strings::Appendf(&ret, "%p:", ptr); + if (num_bytes > 0) { + auto* dev_info = device->tensorflow_gpu_device_info(); + if (!dev_info) { + strings::StrAppend( + &ret, PrintMemory(reinterpret_cast(ptr), num_bytes)); + } else { + string buf; + buf.resize(num_bytes); + DeviceMemoryBase gpu_ptr(ptr, num_bytes); + Status s = dev_info->stream->parent()->SynchronousMemcpyD2H( + gpu_ptr, num_bytes, gtl::string_as_array(&buf)); + strings::StrAppend(&ret, + PrintMemory(gtl::string_as_array(&buf), num_bytes)); + } + } + return ret; +} + +// TODO(pbar) Checksum is called from places without a valid device context. +uint64 GPUUtil::Checksum(Device* gpu_device, + const DeviceContext* device_context, + const Tensor& tensor) { + Tensor copy(tensor.dtype(), tensor.shape()); + Status s; + Notification n; + CopyGPUTensorToCPU(gpu_device, device_context, &tensor, ©, + [&s, &n](Status status) { + s.Update(status); + n.Notify(); + }); + n.WaitForNotification(); + CHECK(s.ok()) << s; + return Checksum(copy); +} + +uint64 GPUUtil::Checksum(const Tensor& tensor) { + const float* fptr = reinterpret_cast(DMAHelper::base(&tensor)); + size_t num_bytes = tensor.TotalBytes(); + size_t num_floats = num_bytes / sizeof(float); + for (size_t i = 0; i < num_floats; ++i) { + CHECK(!std::isnan(fptr[i])) << " i " << i; + } + // TODO(tucker): consider using crc32c instead. + return Hash64(reinterpret_cast(DMAHelper::base(&tensor)), + tensor.TotalBytes(), 0); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h new file mode 100644 index 0000000000..1d8c3a054d --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_util.h @@ -0,0 +1,89 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/common_runtime/gpu/dma_helper.h" +#include "tensorflow/stream_executor/device_memory.h" + +#include "tensorflow/stream_executor/stream.h" + +namespace tensorflow { + +class RecvTensorResponse; +class TensorProto; + +namespace gpu = ::perftools::gputools; + +class GPUUtil { + public: + // "tensor" is GPU-local. "dev" is the hosting GPU. + // "device_context" should be the context of the GPU "_Send" op + // which provides the Tensor. + // Sets all necessasry fields of "proto" by transferring value + // bytes from GPU to CPU RAM. "is_dead" indicates that the + // tensor is dead with an uninit value. + static void SetProtoFromGPU(const Tensor& tensor, Device* dev, + const DeviceContext* device_context, + TensorProto* proto, bool is_dead, + StatusCallback done); + + // Copies "input" to "output" between devices accessible to the + // local process via some DMA-like method. "edge_name" is the name + // of the tensor being copied, for debugging purposes. Depending on + // the type of devices and memory in use, the copy may be performed + // synchronously or asynchronously. 'done' will be invoked only + // after the copy is actually complete. + static void CopyViaDMA(const string& edge_name, + DeviceContext* send_dev_context, + DeviceContext* recv_dev_context, Device* src, + Device* dst, const AllocatorAttributes src_alloc_attr, + const AllocatorAttributes dst_alloc_attr, + const Tensor* input, Tensor* output, + StatusCallback done); + + // Copies the data in 'gpu_tensor' into 'cpu_tensor'. 
+ // 'gpu_tensor''s backing memory must be on 'gpu_device' and + // 'cpu_tensor' must be allocated to be of the same size as + // 'gpu_tensor'. Synchronous: may block. + static void CopyGPUTensorToCPU(Device* gpu_device, + const DeviceContext* device_context, + const Tensor* gpu_tensor, Tensor* cpu_tensor, + StatusCallback done); + + // Blocks until all operations queued on the stream associated with + // "gpu_device" at the time of the call have completed. Returns any + // error pending on the stream at completion. + static Status Sync(Device* gpu_device); + + // Blocks until all operations queued on all streams associated with the + // corresponding GPU device at the time of call have completed. + // Returns any error pending on the stream at completion. + static Status SyncAll(Device* gpu_device); + + // For debugging purpose, given a "device" and a "tensor" allocated + // on the device, return a string printing each byte in the tensor + // (up to a limit). "device" can be either a CPU or a GPU device. + static string MemoryDebugString(const Device* device, Tensor* tensor); + + static perftools::gputools::DeviceMemory AsGPUFloat(const Tensor& t); + + // Computes a checksum over the contents of "tensor", which is allocated + // on "gpu_device". + static uint64 Checksum(Device* gpu_device, + const DeviceContext* device_context, + const Tensor& tensor); + + // Computes a checksum over the contents of "tensor", which is allocated + // in local CPU RAM. + static uint64 Checksum(const Tensor& tensor); + + static void CopyCPUTensorToGPU(const Tensor* cpu_tensor, + const DeviceContext* device_context, + Device* gpu_device, Tensor* gpu_tensor, + StatusCallback done); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc new file mode 100644 index 0000000000..f1b1174a28 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc @@ -0,0 +1,24 @@ +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" +#include "tensorflow/stream_executor/stream.h" + +namespace tensorflow { + +void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, + Device* device, + Tensor* device_tensor, + StatusCallback done) const { + GPUUtil::CopyCPUTensorToGPU(cpu_tensor, this, device, device_tensor, done); +} + +void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, + const string& tensor_name, + Device* device, Tensor* cpu_tensor, + StatusCallback done) { + GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.cc b/tensorflow/core/common_runtime/gpu/pool_allocator.cc new file mode 100644 index 0000000000..52deb7fce2 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/pool_allocator.cc @@ -0,0 +1,269 @@ +#include "tensorflow/core/common_runtime/gpu/pool_allocator.h" + +#include +#include +#include // for munmap + +#include + +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +//#include "prodkernel/api/base/numa.h" + +namespace tensorflow { + +PoolAllocator::PoolAllocator(size_t 
pool_size_limit, bool auto_resize, + SubAllocator* allocator, + RoundUpInterface* size_rounder, string name) + : name_(name), + has_size_limit_(pool_size_limit > 0), + auto_resize_(auto_resize), + pool_size_limit_(pool_size_limit), + allocator_(allocator), + size_rounder_(size_rounder), + allocation_begun_(false) { + if (auto_resize) { + CHECK_LT(0, pool_size_limit) + << "size limit must be > 0 if auto_resize is true."; + } +} + +PoolAllocator::~PoolAllocator() { Clear(); } + +namespace { +// Pools contain Chunks allocatated from the underlying Allocator. +// Chunk alignment is always on kPoolAlignment boundaries. Each Chunk +// begins with a descriptor (ChunkPrefix) that gives its size and a +// pointer to itself. The pointer returned to the user is just past +// the ChunkPrefix. If the user asks for a larger alignment, we will +// increase the size of the chunk, then adjust the returned user +// pointer and also re-write the ChunkPrefix.chunk_ptr value +// immediately before it. This way the Chunk address and size can be +// recovered from the returned user pointer, regardless of alignment. +// Note that this deferencing of the pointers means that we cannot +// handle GPU memory, only CPU memory. +struct ChunkPrefix { + size_t num_bytes; + void* chunk_ptr; +}; +// kPoolAlignment cannot be less than the size of ChunkPrefix. +static const int kPoolAlignment = sizeof(ChunkPrefix); + +void* PrepareChunk(void* chunk, size_t alignment, size_t num_bytes) { + ChunkPrefix* cp = reinterpret_cast(chunk); + cp->num_bytes = num_bytes; + cp->chunk_ptr = chunk; + void* user_ptr = reinterpret_cast(cp + 1); + if (alignment > kPoolAlignment) { + // Move user_ptr forward to the first satisfying offset, and write + // chunk_ptr just before it. + size_t aligned_ptr = reinterpret_cast(user_ptr) + alignment; + user_ptr = reinterpret_cast(aligned_ptr & ~(alignment - 1)); + (reinterpret_cast(user_ptr) - 1)->chunk_ptr = chunk; + } + // Safety check that user_ptr is always past the ChunkPrefix. + CHECK_GE(user_ptr, reinterpret_cast(chunk) + 1); + return user_ptr; +} + +ChunkPrefix* FindPrefix(void* user_ptr) { + ChunkPrefix* cp = reinterpret_cast(user_ptr) - 1; + return reinterpret_cast(cp->chunk_ptr); +} +} // namespace + +void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { + if (!allocation_begun_) allocation_begun_ = true; + if (num_bytes == 0) return nullptr; + + // If alignment is larger than kPoolAlignment, increase num_bytes so that we + // are guaranteed to be able to return an aligned ptr by advancing user_ptr + // without overrunning the end of the chunk. + if (alignment > kPoolAlignment) { + num_bytes += alignment; + } + num_bytes += sizeof(ChunkPrefix); + num_bytes = size_rounder_->RoundUp(num_bytes); + PtrRecord* pr = nullptr; + if (has_size_limit_) { + { + mutex_lock lock(mutex_); + auto iter = pool_.find(num_bytes); + if (iter == pool_.end()) { + allocated_count_++; + // Deliberately fall out of lock scope before + // calling the allocator. No further modification + // to the pool will be performed. + } else { + get_from_pool_count_++; + pr = iter->second; + RemoveFromList(pr); + pool_.erase(iter); + // Fall out of lock scope and do the result without the lock held. 
+ } + } + } + if (pr != nullptr) { + void* r = pr->ptr; + delete pr; + return PrepareChunk(r, alignment, num_bytes); + } else { + void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes); + for (auto v : alloc_visitors_) { + v(ptr, num_bytes); + } + return PrepareChunk(ptr, alignment, num_bytes); + } +} + +void PoolAllocator::DeallocateRaw(void* ptr) { + if (ptr == nullptr) return; + ChunkPrefix* cp = FindPrefix(ptr); + CHECK_LE((void*)cp, (void*)ptr); + if (!has_size_limit_ && !auto_resize_) { + for (auto v : free_visitors_) { + v(cp, cp->num_bytes); + } + allocator_->Free(cp, cp->num_bytes); + } else { + mutex_lock lock(mutex_); + ++put_count_; + while (pool_.size() >= pool_size_limit_) { + EvictOne(); + } + PtrRecord* pr = new PtrRecord; + pr->num_bytes = cp->num_bytes; + pr->ptr = cp; + AddToList(pr); + pool_.insert(std::make_pair(cp->num_bytes, pr)); + } +} + +void PoolAllocator::Clear() { + if (has_size_limit_) { + mutex_lock lock(mutex_); + for (auto iter : pool_) { + PtrRecord* pr = iter.second; + for (auto v : free_visitors_) { + v(pr->ptr, pr->num_bytes); + } + allocator_->Free(pr->ptr, pr->num_bytes); + delete pr; + } + pool_.clear(); + get_from_pool_count_ = 0; + put_count_ = 0; + allocated_count_ = 0; + evicted_count_ = 0; + lru_head_ = nullptr; + lru_tail_ = nullptr; + } +} + +void PoolAllocator::RemoveFromList(PtrRecord* pr) { + if (pr->prev == nullptr) { + DCHECK_EQ(lru_head_, pr); + lru_head_ = nullptr; + } else { + pr->prev->next = pr->next; + } + if (pr->next == nullptr) { + DCHECK_EQ(lru_tail_, pr); + lru_tail_ = pr->prev; + } else { + pr->next->prev = pr->prev; + if (lru_head_ == nullptr) { + lru_head_ = pr->next; + } + } +} + +void PoolAllocator::AddToList(PtrRecord* pr) { + pr->prev = nullptr; + if (lru_head_ == nullptr) { + CHECK(lru_tail_ == nullptr); + lru_tail_ = pr; + pr->next = nullptr; + } else { + pr->next = lru_head_; + pr->next->prev = pr; + } + lru_head_ = pr; +} + +void PoolAllocator::EvictOne() { + DCHECK(lru_tail_ != nullptr); + PtrRecord* prec = lru_tail_; + RemoveFromList(prec); + auto iter = pool_.find(prec->num_bytes); + while (iter->second != prec) { + ++iter; + DCHECK(iter != pool_.end()); + } + pool_.erase(iter); + for (auto v : free_visitors_) { + v(prec->ptr, prec->num_bytes); + } + allocator_->Free(prec->ptr, prec->num_bytes); + delete prec; + ++evicted_count_; + // Auto-resizing, and warning messages. + static const double kTolerable = 2e-3; + static const int kCheckInterval = 1000; + static const double kIncreaseFactor = 1.1; + static const int kMinPoolSize = 100; + if (0 == evicted_count_ % kCheckInterval) { + const double eviction_rate = + evicted_count_ / static_cast(put_count_); + const int64 alloc_request_count = allocated_count_ + get_from_pool_count_; + const double alloc_rate = + allocated_count_ / static_cast(alloc_request_count); + static int log_counter = 0; + // (counter increment not thread safe but it's just for logging, so we + // don't care). + bool should_log = ((log_counter++ % 10) == 0); + if (should_log) { + LOG(WARNING) << "PoolAllocator: After " << alloc_request_count + << " get requests, put_count=" << put_count_ + << " evicted_count=" << evicted_count_ + << " eviction_rate=" << eviction_rate + << " and unsatisfied allocation rate=" << alloc_rate; + } + if (auto_resize_ && (eviction_rate > kTolerable) && + (alloc_rate > kTolerable)) { + size_t new_size_limit = (pool_size_limit_ < kMinPoolSize) + ? 
kMinPoolSize + : (kIncreaseFactor * pool_size_limit_); + if (should_log) { + LOG(INFO) << "Raising pool_size_limit_ from " << pool_size_limit_ + << " to " << new_size_limit; + } + pool_size_limit_ = new_size_limit; + // Reset all the counters so that ratios are relative to new sizes + // at next test interval. + put_count_ = 0; + allocated_count_ = 0; + evicted_count_ = 0; + get_from_pool_count_ = 0; + } + } +} + +void PoolAllocator::AddAllocVisitor(Visitor visitor) { + mutex_lock lock(mutex_); + CHECK(!allocation_begun_) + << "AddAllocVisitor may not be called after pool allocation " + << "has begun."; + alloc_visitors_.push_back(visitor); +} + +void PoolAllocator::AddFreeVisitor(Visitor visitor) { + mutex_lock lock(mutex_); + CHECK(!allocation_begun_) + << "AddFreeVisitor may not be called after pool allocation " + << "has begun."; + free_visitors_.push_back(visitor); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/gpu/pool_allocator.h new file mode 100644 index 0000000000..d10aabe88a --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/pool_allocator.h @@ -0,0 +1,202 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_ + +// Simple LRU pool allocators for various flavors of CPU RAM that +// implement the VisitableAllocator interface. GPU memory is managed +// by GPURegionAllocator. + +#include +#include +#include +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h" +#include "tensorflow/stream_executor/stream_executor.h" + +namespace tensorflow { + +// Interface of an object that does the underlying alloc/free of memory. +class SubAllocator { + public: + virtual ~SubAllocator() {} + virtual void* Alloc(size_t alignment, size_t num_bytes) = 0; + virtual void Free(void* ptr, size_t num_bytes) = 0; +}; + +// Interface of an object that rounds up integers. +class RoundUpInterface { + public: + virtual ~RoundUpInterface() {} + virtual size_t RoundUp(size_t num_bytes) = 0; +}; + +// Size-limited pool of memory buffers obtained from a SubAllocator +// instance. Pool eviction policy is LRU. +class PoolAllocator : public VisitableAllocator { + public: + // "pool_size_limit" is the maximum number of returned, re-usable + // memory buffers to keep in the pool. If pool_size_limit == 0, the + // pool is effectively a thin wrapper around the allocator. + // If "auto_resize" is true, then the pool_size_limit will gradually + // be raised so that deallocations happen very rarely, if at all. + // Transitory start-up objects may deallocate, but the long-term + // working-set should not. Auto-resizing can raise pool_size_limit + // but will never lower it. + // "allocator" is the object that performs the underlying memory + // malloc/free operations. This object takes ownership of allocator. + PoolAllocator(size_t pool_size_limit, bool auto_resize, + SubAllocator* allocator, RoundUpInterface* size_rounder, + string name); + ~PoolAllocator() override; + + string Name() override { return name_; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + + void DeallocateRaw(void* ptr) override; + + // REQUIRES: The following functions may only be called prior + // to the first Allocate*() call. Once allocation has begun, it is + // illegal to register another visitor. 
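+  // A minimal usage sketch (editorial; "my_visitor" is a hypothetical
+  // Visitor value): register visitors before the first allocation, e.g.
+  //   PoolAllocator pool(100 /*pool_size_limit*/, true /*auto_resize*/,
+  //                      new BasicCPUAllocator, new NoopRounder, "example");
+  //   pool.AddAllocVisitor(my_visitor);      // OK: nothing allocated yet.
+  //   void* p = pool.AllocateRaw(64, 1024);
+  //   pool.AddAllocVisitor(my_visitor);      // CHECK-fails: allocation begun.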
+ + void AddAllocVisitor(Visitor visitor) override; + + void AddFreeVisitor(Visitor visitor) override; + + // Allocate an unused memory region of size "num_bytes". Fetch from + // the pool if available, otherwise call allocator_. + void* Get(size_t num_bytes); + + // Return a no-longer needed memory region to the pool. It is an error + // to deference "ptr" after this call. If the pool is full, the least + // recently used region will be deallocated. + void Put(void* ptr, size_t num_bytes); + + // Reset the pool to empty. + void Clear(); + + // The following accessors permit monitoring the effectiveness of + // the pool at avoiding repeated malloc/frees on the underlying + // allocator. Read locks are not taken on the theory that value + // consistency with other threads is not important. + + // Number of Get() requests satisfied from pool. + int64 get_from_pool_count() const NO_THREAD_SAFETY_ANALYSIS { + return get_from_pool_count_; + } + // Number of Put() requests. + int64 put_count() const NO_THREAD_SAFETY_ANALYSIS { return put_count_; } + // Number of Get() requests requiring a fresh allocation. + int64 allocated_count() const NO_THREAD_SAFETY_ANALYSIS { + return allocated_count_; + } + // Number of pool evictions. + int64 evicted_count() const NO_THREAD_SAFETY_ANALYSIS { + return evicted_count_; + } + // Current size limit. + size_t size_limit() const NO_THREAD_SAFETY_ANALYSIS { + return pool_size_limit_; + } + + private: + struct PtrRecord { + void* ptr; + size_t num_bytes; + PtrRecord* prev; + PtrRecord* next; + }; + + // Remove "pr" from the double-linked LRU list. + void RemoveFromList(PtrRecord* pr) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Add "pr" to the head of the double-linked LRU list. + void AddToList(PtrRecord* pr) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Delete the least recently used record. + void EvictOne() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + const string name_; + const bool has_size_limit_; + const bool auto_resize_; + size_t pool_size_limit_; + std::unique_ptr allocator_; + std::unique_ptr size_rounder_; + mutex mutex_; + std::multimap pool_ GUARDED_BY(mutex_); + PtrRecord* lru_head_ GUARDED_BY(mutex_) = nullptr; + PtrRecord* lru_tail_ GUARDED_BY(mutex_) = nullptr; + int64 get_from_pool_count_ GUARDED_BY(mutex_) = 0; + int64 put_count_ GUARDED_BY(mutex_) = 0; + int64 allocated_count_ GUARDED_BY(mutex_) = 0; + int64 evicted_count_ GUARDED_BY(mutex_) = 0; + // Write access to these is guarded by mutex_, but not read + // access. They may only be modified prior to the first + // allocation. Later attempts to modify will fail. + std::vector alloc_visitors_; + std::vector free_visitors_; + std::atomic allocation_begun_; +}; + +// Do-nothing rounder. Passes through sizes unchanged. +class NoopRounder : public RoundUpInterface { + public: + size_t RoundUp(size_t num_bytes) override { return num_bytes; } +}; + +// Power of 2 rounder: rounds up to nearest power of 2 size. +class Pow2Rounder : public RoundUpInterface { + public: + size_t RoundUp(size_t num_bytes) override { + return 1uLL << Log2Ceiling64(num_bytes); + } +}; + +class BasicCPUAllocator : public SubAllocator { + public: + ~BasicCPUAllocator() override {} + + void* Alloc(size_t alignment, size_t num_bytes) override { + return port::aligned_malloc(num_bytes, alignment); + } + void Free(void* ptr, size_t num_bytes) override { free(ptr); } +}; + +// Allocator for pinned CPU RAM that is made known to CUDA for the +// purpose of efficient DMA with a GPU. 
+class CUDAHostAllocator : public SubAllocator { + public: + // Note: stream_exec cannot be null. + explicit CUDAHostAllocator(perftools::gputools::StreamExecutor* stream_exec) + : stream_exec_(stream_exec) { + CHECK(stream_exec_ != nullptr); + } + ~CUDAHostAllocator() override {} + + void* Alloc(size_t alignment, size_t num_bytes) override { + void* ptr = nullptr; + if (num_bytes > 0) { + ptr = stream_exec_->HostMemoryAllocate(num_bytes); + if (ptr == nullptr) { + LOG(FATAL) << "could not allocate pinned host memory of size: " + << num_bytes; + } + } + return ptr; + } + + void Free(void* ptr, size_t num_bytes) override { + if (ptr != nullptr) { + stream_exec_->HostMemoryDeallocate(ptr); + } + } + + private: + perftools::gputools::StreamExecutor* stream_exec_; // not owned, non-null + + TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc new file mode 100644 index 0000000000..ca409b2b4c --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc @@ -0,0 +1,203 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/common_runtime/gpu/pool_allocator.h" + +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/platform.h" +#include + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { +namespace { + +TEST(PoolAllocatorTest, ZeroSizeBuffers) { + gpu::Platform* platform = + gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); + PoolAllocator pool( + 2 /*pool_size_limit*/, false /*auto_resize*/, + new CUDAHostAllocator( + platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0)) + .ValueOrDie()), + new NoopRounder, "pool"); + + EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/)); + pool.DeallocateRaw(nullptr); // Should not crash. + EXPECT_EQ(0, pool.get_from_pool_count()); + EXPECT_EQ(0, pool.put_count()); + EXPECT_EQ(0, pool.allocated_count()); + EXPECT_EQ(0, pool.evicted_count()); +} + +TEST(PoolAllocatorTest, ZeroSizePool) { + gpu::Platform* platform = + gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); + PoolAllocator pool( + 0 /*pool_size_limit*/, false /*auto_resize*/, + new CUDAHostAllocator( + platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0)) + .ValueOrDie()), + new NoopRounder, "pool"); + + EXPECT_EQ(0, pool.get_from_pool_count()); + EXPECT_EQ(0, pool.put_count()); + EXPECT_EQ(0, pool.allocated_count()); + EXPECT_EQ(0, pool.evicted_count()); + + // All allocations should bypass the pool and return valid pointers. 
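+  // (Editorial note) With pool_size_limit == 0, has_size_limit_ is false, so
+  // AllocateRaw() always calls the SubAllocator directly and DeallocateRaw()
+  // frees immediately; none of the counters checked below should ever move.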
+ for (int i = 0; i < 3; ++i) { + void* p0 = pool.AllocateRaw(4, 0); + void* p4 = pool.AllocateRaw(4, 4); + void* p12 = pool.AllocateRaw(4, 12); + EXPECT_EQ(nullptr, p0); + EXPECT_NE(nullptr, p4); + EXPECT_NE(nullptr, p12); + pool.DeallocateRaw(p0); + pool.DeallocateRaw(p4); + pool.DeallocateRaw(p12); + } + EXPECT_EQ(0, pool.get_from_pool_count()); + EXPECT_EQ(0, pool.put_count()); + EXPECT_EQ(0, pool.allocated_count()); + EXPECT_EQ(0, pool.evicted_count()); +} + +TEST(PoolAllocatorTest, Alignment) { + gpu::Platform* platform = + gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); + PoolAllocator pool( + 0 /*pool_size_limit*/, false /*auto_resize*/, + new CUDAHostAllocator( + platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0)) + .ValueOrDie()), + new NoopRounder, "pool"); + for (int i = 0; i < 16; ++i) { + size_t alignment = 1 << i; + void* p = pool.AllocateRaw(alignment, 111); + EXPECT_TRUE(p != nullptr); + EXPECT_EQ(0, reinterpret_cast(p) & (alignment - 1)) + << "ptr: " << p << " alignment " << alignment; + // Intentionally don't deallocate, to test that destruction of + // the PoolAllocator frees all pending memory. + } +} + +TEST(PoolAllocatorTest, AutoResize) { + PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/, + new BasicCPUAllocator, new NoopRounder, "pool"); + + // Alloc/dealloc 10 sizes just a few times, confirming pool size + // stays at 2. + for (int i = 0; i < 10; ++i) { + void* p = pool.AllocateRaw(4, 64 << i); + pool.DeallocateRaw(p); + } + EXPECT_EQ(0, pool.get_from_pool_count()); + EXPECT_EQ(10, pool.allocated_count()); + EXPECT_EQ(10, pool.put_count()); + EXPECT_EQ(8, pool.evicted_count()); + EXPECT_EQ(2, pool.size_limit()); + + // Then repeat 1200 times. Pool size limit should jump to 100. + for (int j = 0; j < 120; ++j) { + for (int i = 0; i < 10; ++i) { + void* p = pool.AllocateRaw(4, 64 << i); + pool.DeallocateRaw(p); + } + } + EXPECT_EQ(100, pool.size_limit()); +} + +TEST(PoolAllocatorTest, CudaHostAllocator) { + gpu::Platform* platform = + gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); + PoolAllocator pool( + 2 /*pool_size_limit*/, false /*auto_resize*/, + new CUDAHostAllocator( + platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0)) + .ValueOrDie()), + new NoopRounder, "pool"); + + // Repeatedly Get a 16-byte value, confirming that there's only + // one real allocation. + void* p1_16 = pool.AllocateRaw(4, 16); + EXPECT_EQ(0, pool.get_from_pool_count()); + EXPECT_EQ(1, pool.allocated_count()); + EXPECT_NE(nullptr, p1_16); + pool.DeallocateRaw(p1_16); + // Pool contents {16} + EXPECT_EQ(1, pool.put_count()); + void* p2_16 = pool.AllocateRaw(4, 16); // Get it again. + EXPECT_EQ(1, pool.get_from_pool_count()); + EXPECT_EQ(1, pool.allocated_count()); + EXPECT_EQ(p1_16, p2_16); // Same pointer value + pool.DeallocateRaw(p2_16); // Put it back. + // Pool contents {16} + EXPECT_EQ(2, pool.put_count()); + + // Get two more values of different sizes. + void* p3_4 = pool.AllocateRaw(4, 4); + EXPECT_EQ(2, pool.allocated_count()); + EXPECT_NE(p1_16, p3_4); // Different pointer value + EXPECT_NE(nullptr, p3_4); + pool.DeallocateRaw(p3_4); // Put it back. Pool is now full. + // Pool contents {4, 16} + EXPECT_EQ(3, pool.put_count()); + void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer. + EXPECT_NE(nullptr, p4_2); + EXPECT_EQ(0, pool.evicted_count()); + + // The pool is full: when we put back p4_2, the 16-byte buffer + // should be evicted since it was least recently inserted. 
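+  // (Editorial sketch of the LRU bookkeeping at this point, assuming the
+  //  NoopRounder used above; sizes refer to the original requests.)
+  //   pool = {4, 16}, with the 16-byte chunk at the LRU tail (oldest insert)
+  //   DeallocateRaw(p4_2): pool is full -> EvictOne() frees the 16-byte
+  //   chunk, then the 2-byte chunk is inserted -> pool = {2, 4}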
+ pool.DeallocateRaw(p4_2); + // Pool contents {2, 4} + EXPECT_EQ(4, pool.put_count()); + EXPECT_EQ(1, pool.evicted_count()); + + // Re-getting and putting size 2 or 4 should not alter pool size or + // num-evicted. + void* p5_4 = pool.AllocateRaw(4, 4); + EXPECT_NE(nullptr, p5_4); + pool.DeallocateRaw(p5_4); + void* p6_2 = pool.AllocateRaw(4, 2); + EXPECT_NE(nullptr, p6_2); + pool.DeallocateRaw(p6_2); + EXPECT_EQ(3, pool.get_from_pool_count()); + EXPECT_EQ(6, pool.put_count()); + EXPECT_EQ(3, pool.allocated_count()); + EXPECT_EQ(1, pool.evicted_count()); + + pool.Clear(); + EXPECT_EQ(0, pool.get_from_pool_count()); + EXPECT_EQ(0, pool.put_count()); + EXPECT_EQ(0, pool.allocated_count()); + EXPECT_EQ(0, pool.evicted_count()); +} + +TEST(PoolAllocatorTest, Pow2Rounder) { + Pow2Rounder rounder; + EXPECT_EQ(1, rounder.RoundUp(1)); + EXPECT_EQ(2, rounder.RoundUp(2)); + EXPECT_EQ(16, rounder.RoundUp(9)); + EXPECT_EQ(16, rounder.RoundUp(16)); + EXPECT_EQ(65536, rounder.RoundUp(41234)); + EXPECT_EQ(65536, rounder.RoundUp(65535)); + EXPECT_EQ(65536, rounder.RoundUp(65536)); +} + +TEST(PoolAllocatorTest, Name) { + gpu::Platform* platform = + gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); + PoolAllocator pool( + 2 /*pool_size_limit*/, false /*auto_resize*/, + new CUDAHostAllocator( + platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0)) + .ValueOrDie()), + new NoopRounder, "pool"); + EXPECT_EQ("pool", pool.Name()); +} + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc new file mode 100644 index 0000000000..70ac6130c2 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/process_state.cc @@ -0,0 +1,220 @@ +#include "tensorflow/core/common_runtime/gpu/process_state.h" + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h" +#include "tensorflow/core/common_runtime/gpu/pool_allocator.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" + +#if defined(PLATFORM_GOOGLE) +DEFINE_bool(record_mem_types, false, + "If true, record attributes of memory allocations and " + "dyanmically check for appropriate use of registered memory." 
+ "Should only be true for debugging or diagnosis of " + "performance issues."); +DEFINE_bool(brain_mem_reg_cuda_dma, true, + "If true, register CPU RAM used to copy to/from GPU RAM " + "with the CUDA driver."); +DEFINE_bool(brain_gpu_use_bfc_allocator, false, + "If true, uses the Best-Fit GPU allocator."); +DEFINE_bool(brain_gpu_region_allocator_debug, false, + "If true, checks for memory overwrites by writing " + "distinctive patterns on both ends of allocated memory."); +DEFINE_bool(brain_gpu_region_allocator_reset_to_nan, false, + "If true, initializes all new Malloc buffers to NaN, " + "and resets the buffer to NaN upon Free."); + +#else +bool FLAGS_record_mem_types = false; +bool FLAGS_brain_mem_reg_cuda_dma = true; +bool FLAGS_brain_gpu_region_allocator_debug = false; +bool FLAGS_brain_gpu_region_allocator_reset_to_nan = false; +bool FLAGS_brain_gpu_use_bfc_allocator = false; +#endif + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +ProcessState* ProcessState::instance_ = nullptr; + +/*static*/ ProcessState* ProcessState::singleton() { + if (instance_ == nullptr) { + instance_ = new ProcessState; + } + + return instance_; +} + +ProcessState::ProcessState() : gpu_count_(0) { + CHECK(instance_ == nullptr); + instance_ = this; +} + +ProcessState::~ProcessState() { + for (auto p : gpu_allocators_) { + delete p; + } + instance_ = nullptr; +} + +string ProcessState::MemDesc::DebugString() { + return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ", + gpu_registered, ", nic: ", nic_registered); +} + +ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { + if (FLAGS_record_mem_types) { + auto iter = mem_desc_map_.find(ptr); + if (iter != mem_desc_map_.end()) { + return iter->second; + } + } + return MemDesc(); +} + +void ProcessState::SetGPUCount(int c) { + CHECK(gpu_count_ == 0 || gpu_count_ == c) + << "Cannot call SetGPUCount with a non-zero value " + << "not equal to prior set value."; + gpu_count_ = c; +} + +int ProcessState::GPUCount() const { return gpu_count_; } + +Allocator* ProcessState::GetGPUAllocator(int gpu_id, size_t total_bytes) { +#if GOOGLE_CUDA + mutex_lock lock(mu_); + gpu::Platform* gpu_platform = GPUMachineManager(); + + // Verify that gpu_id is legitimate. + CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount()) + << "gpu_id is outside discovered device range"; + + if (gpu_id >= static_cast(gpu_allocators_.size())) { + gpu_allocators_.resize(gpu_id + 1); + if (FLAGS_record_mem_types) gpu_al_.resize(gpu_id + 1); + } + + if (gpu_allocators_[gpu_id] == nullptr) { + VisitableAllocator* gpu_allocator; + + if (FLAGS_brain_gpu_use_bfc_allocator) { + gpu_allocator = new GPUBFCAllocator(gpu_id, total_bytes); + } else { + gpu_allocator = new GPURegionAllocator(gpu_id, total_bytes); + } + + if (FLAGS_brain_gpu_region_allocator_debug) { + gpu_allocator = new GPUDebugAllocator(gpu_allocator, gpu_id); + } + if (FLAGS_brain_gpu_region_allocator_reset_to_nan) { + gpu_allocator = new GPUNanResetAllocator(gpu_allocator, gpu_id); + } + + gpu_allocators_[gpu_id] = gpu_allocator; + + // If there are any pending AllocVisitors for this bus, add + // them now. 
+ gpu::StreamExecutor* se = + gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie(); + int bus_id = se->GetDeviceDescription().numa_node(); + if (bus_id < static_cast(gpu_visitors_.size())) { + for (auto v : gpu_visitors_[bus_id]) { + gpu_allocators_[gpu_id]->AddAllocVisitor(v); + } + } + if (FLAGS_record_mem_types) { + MemDesc md; + md.loc = MemDesc::GPU; + md.dev_index = gpu_id; + md.gpu_registered = false; + md.nic_registered = true; + if (static_cast(gpu_al_.size()) <= gpu_id) + gpu_al_.resize(gpu_id + 1); + gpu_al_[gpu_id] = new internal::RecordingAllocator( + &mem_desc_map_, gpu_allocators_[gpu_id], md, &mu_); + } + } + if (FLAGS_record_mem_types) return gpu_al_[gpu_id]; + return gpu_allocators_[gpu_id]; +#else + LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda."; + return nullptr; +#endif // GOOGLE_CUDA +} + +Allocator* ProcessState::GetCPUAllocator(int numa_node) { + // Although we're temporarily ignoring numa_node, check for legality. + CHECK_GE(numa_node, 0); + // TODO(tucker): actually maintain separate CPUAllocators for + // different numa_nodes. For now, just one. + numa_node = 0; + mutex_lock lock(mu_); + while (cpu_allocators_.size() <= static_cast(numa_node)) { + cpu_allocators_.push_back(new PoolAllocator( + 100 /*pool_size_limit*/, true /*auto_resize*/, new BasicCPUAllocator(), + new NoopRounder, "cpu_pool")); + } + return cpu_allocators_[0]; +} + +Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) { + if (gpu_count_ == 0 || !FLAGS_brain_mem_reg_cuda_dma) { + return GetCPUAllocator(numa_node); + } + // Although we're temporarily ignoring numa_node, check for legality. + CHECK_GE(numa_node, 0); + // TODO(tucker): actually maintain separate CPUAllocators for + // different numa_nodes. For now, just one. + numa_node = 0; + mutex_lock lock(mu_); + while (static_cast(cuda_host_allocators_.size()) <= numa_node) { + // CUDAHost alloc the same across all gpus, so just get the + // executor for the first device. 
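+    // (Editorial sketch; the caller code is hypothetical.)  A host->device
+    // copy path would typically fetch this allocator once and reuse it:
+    //   Allocator* host =
+    //       ProcessState::singleton()->GetCUDAHostAllocator(0 /*numa_node*/);
+    //   void* staging = host->AllocateRaw(64 /*alignment*/, len);
+    // The Pow2Rounder below buckets requests by power-of-two size, so
+    // repeated requests of similar sizes are served from the pool rather
+    // than by a fresh pinned allocation.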
+ gpu::Platform* gpu_platform = GPUMachineManager(); + gpu::StreamExecutor* se = gpu_platform->ExecutorForDevice(0).ValueOrDie(); + CHECK(se); + cuda_host_allocators_.push_back(new PoolAllocator( + 100 /*pool_size_limit*/, true /*auto_resize*/, + new CUDAHostAllocator(se), new Pow2Rounder, "cuda_host")); + if (FLAGS_record_mem_types) { + MemDesc md; + md.loc = MemDesc::CPU; + md.dev_index = 0; + md.gpu_registered = true; + md.nic_registered = false; + cuda_al_.push_back(new internal::RecordingAllocator( + &mem_desc_map_, cuda_host_allocators_.back(), md, &mu_)); + } + } + if (FLAGS_record_mem_types) return cuda_al_[0]; + return cuda_host_allocators_[0]; +} + +void ProcessState::AddGPUAllocVisitor(int bus_id, AllocVisitor visitor) { +#if GOOGLE_CUDA + mutex_lock lock(mu_); + gpu::Platform* gpu_platform = GPUMachineManager(); + for (int gpu_id = 0; gpu_id < static_cast(gpu_allocators_.size()); + ++gpu_id) { + gpu::StreamExecutor* se = + gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie(); + if (gpu_allocators_[gpu_id] && + se->GetDeviceDescription().numa_node() == bus_id) { + gpu_allocators_[gpu_id]->AddAllocVisitor(visitor); + } + } + while (bus_id >= static_cast(gpu_visitors_.size())) { + gpu_visitors_.push_back(std::vector()); + } + gpu_visitors_[bus_id].push_back(visitor); +#endif // GOOGLE_CUDA +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h new file mode 100644 index 0000000000..527d12c10d --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/process_state.h @@ -0,0 +1,140 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +class Allocator; +class VisitableAllocator; +class PoolAllocator; + +// Singleton that manages per-process state, e.g. allocation +// of shared resources. +class ProcessState { + public: + static ProcessState* singleton(); + + // Descriptor for memory allocation attributes, used by optional + // runtime correctness analysis logic. + struct MemDesc { + enum MemLoc { CPU, GPU }; + MemLoc loc; + int dev_index; + bool gpu_registered; + bool nic_registered; + MemDesc() + : loc(CPU), + dev_index(0), + gpu_registered(false), + nic_registered(false) {} + string DebugString(); + }; + + // Records the number of GPUs available in the local process. + // It is a fatal error to call this with a value != to the value + // in a prior call. + void SetGPUCount(int c); + + // Returns number of GPUs available in local process, as set by + // SetGPUCount(); Returns 0 if SetGPUCount has not been called. + int GPUCount() const; + + // Returns what we know about the memory at ptr. + // If we know nothing, it's called CPU 0 with no other attributes. + MemDesc PtrType(const void* ptr); + + // Returns the one CPUAllocator used for the given numa_node. + // TEMPORY: ignores numa_node. + Allocator* GetCPUAllocator(int numa_node); + + // Returns the one GPU allocator used for the indexed GPU. + // Note that this is a system GPU index, not (necessarily) a brain + // device index. + // + // 'total_bytes' is the total number of bytes that should be made + // available to the allocator. The first call to this function for + // a given gpu_id creates the allocator, so only the total_bytes + // used on that first call is used. 
+ // + // REQUIRES: gpu_id must be a valid ordinal for a GPU available in the + // current system environment. Otherwise returns nullptr. + Allocator* GetGPUAllocator(int gpu_id, size_t total_bytes); + + Allocator* GetCUDAHostAllocator(int numa_node); + + // Registers a function to be called once on every new Region + // allocated by every GPURegionAllocator proximate to the specified + // bus. The AllocVisitor is provided with a memory pointer and the + // size of the area it identifies. The pointer is not guaranteed to + // be valid after the call terminates. The intention is for this + // interface to be used for network device memory registration. + // "bus_id" is platform-specific. On many platforms it + // should be 0. On machines with multiple PCIe buses, it should be + // the index of one of the PCIe buses. If the the bus_id is invalid, + // results are undefined. + typedef std::function AllocVisitor; + void AddGPUAllocVisitor(int bus_id, AllocVisitor visitor); + + typedef std::unordered_map MDMap; + + protected: + ProcessState(); + + static ProcessState* instance_; + + mutex mu_; + int gpu_count_; + + std::vector cpu_allocators_ GUARDED_BY(mu_); + std::vector gpu_allocators_ GUARDED_BY(mu_); + std::vector> gpu_visitors_ GUARDED_BY(mu_); + std::vector cuda_host_allocators_ GUARDED_BY(mu_); + + virtual ~ProcessState(); + + // Optional RecordingAllocators that wrap the corresponding + // Allocators for runtime attribute use analysis. + MDMap mem_desc_map_; + std::vector cpu_al_ GUARDED_BY(mu_); + std::vector gpu_al_ GUARDED_BY(mu_); + std::vector cuda_al_ GUARDED_BY(mu_); +}; + +namespace internal { +class RecordingAllocator : public Allocator { + public: + RecordingAllocator(ProcessState::MDMap* mm, Allocator* a, + ProcessState::MemDesc md, mutex* mu) + : mm_(mm), a_(a), md_(md), mu_(mu) {} + + string Name() override { return a_->Name(); } + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* p = a_->AllocateRaw(alignment, num_bytes); + mutex_lock l(*mu_); + (*mm_)[p] = md_; + return p; + } + void DeallocateRaw(void* p) override { + mutex_lock l(*mu_); + auto iter = mm_->find(p); + mm_->erase(iter); + a_->DeallocateRaw(p); + } + bool TracksAllocationSizes() override { return a_->TracksAllocationSizes(); } + size_t RequestedSize(void* p) override { return a_->RequestedSize(p); } + size_t AllocatedSize(void* p) override { return a_->AllocatedSize(p); } + ProcessState::MDMap* mm_; // not owned + Allocator* a_; // not owned + ProcessState::MemDesc md_; + mutex* mu_; +}; +} // namespace internal +} // namespace tensorflow +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_ diff --git a/tensorflow/core/common_runtime/gpu/visitable_allocator.h b/tensorflow/core/common_runtime/gpu/visitable_allocator.h new file mode 100644 index 0000000000..23feed9aab --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/visitable_allocator.h @@ -0,0 +1,30 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_ + +#include +#include "tensorflow/core/framework/allocator.h" + +namespace tensorflow { + +// Subclass VisitableAllocator instead of Allocator when a memory +// allocator needs to enable some kind of registration/deregistration +// of memory areas. +class VisitableAllocator : public Allocator { + public: + // Visitor gets called with a pointer to a memory area and its + // size in bytes. 
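+  // For example (editorial sketch), a registration layer could pass:
+  //   allocator->AddAllocVisitor([](void* ptr, size_t num_bytes) {
+  //     LOG(INFO) << "registering " << num_bytes << " bytes at " << ptr;
+  //   });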
+ typedef std::function Visitor; + + // Register a visitor guaranteed to be called exactly once on each + // chunk of memory newly allocated from the underlying device. + // Typically, chunks will be reused and possibly sub-divided by a + // pool manager, so the calls will happen only once per process + // execution, not once per tensor (re)allocation. + virtual void AddAllocVisitor(Visitor visitor) = 0; + + // Register a visitor guaranteed to be called on each chunk of + // memory returned to the underlying device. + virtual void AddFreeVisitor(Visitor visitor) = 0; +}; +} // namespace tensorflow +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h new file mode 100644 index 0000000000..03fd9a97c3 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -0,0 +1,45 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_base.h" + +namespace perftools { +namespace gputools { +class Stream; +} // namespace gputools +} // namespace perftools + +namespace tensorflow { + +namespace gpu = ::perftools::gputools; + +class GPUDeviceContext : public DeviceContext { + public: + GPUDeviceContext(int stream_id, gpu::Stream* stream) + : stream_id_(stream_id), stream_(stream) {} + + ~GPUDeviceContext() override {} + + gpu::Stream* stream() const override { return stream_; } + int stream_id() const { return stream_id_; } + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, + StatusCallback done) const override; + + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + const string& edge_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + + void MaintainLifetimeOnStream( + const Tensor* t, perftools::gputools::Stream* stream) const override {} + + private: + int stream_id_; + gpu::Stream* stream_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc new file mode 100644 index 0000000000..28afc95c1b --- /dev/null +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -0,0 +1,160 @@ +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session_options.h" + +#if defined(PLATFORM_GOOGLE) +DECLARE_bool(brain_gpu_use_bfc_allocator); +#else +extern bool FLAGS_brain_gpu_use_bfc_allocator; +#endif + +namespace tensorflow { +namespace test { + +Benchmark::Benchmark(const string& device, Graph* g, + const SessionOptions* 
options, Graph* init) { + RequireDefaultOps(); + + FLAGS_brain_gpu_use_bfc_allocator = true; + + SessionOptions default_options; + if (!options) { + options = &default_options; + } + + testing::StopTiming(); + string t = str_util::Uppercase(device); + device_ = + DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0"); + CHECK(device_) << "Could not create a " << device << " device"; + + pool_ = new thread::ThreadPool(options->env, "blocking", + port::NumSchedulableCPUs()); + + auto runner = [this](std::function closure) { + pool_->Schedule(closure); + }; + + rendez_ = NewLocalRendezvous(); + + if (init) { + Executor* init_exec; + TF_CHECK_OK(NewLocalExecutor( + { + device_, nullptr, false, + [this](const NodeDef& ndef, OpKernel** kernel) { + return CreateNonCachedKernel(device_, nullptr, ndef, kernel); + }, + [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }, + }, + init, &init_exec)); + Executor::Args args; + args.rendezvous = rendez_; + args.runner = runner; + TF_CHECK_OK(init_exec->Run(args)); + delete init_exec; + } + + TF_CHECK_OK(NewLocalExecutor( + { + device_, + nullptr, + false, + [this](const NodeDef& ndef, OpKernel** kernel) { + return CreateNonCachedKernel(device_, nullptr, ndef, kernel); + }, + [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }, + }, + g, &exec_)); +} + +Benchmark::~Benchmark() { + if (device_) { + rendez_->Unref(); + delete exec_; + delete device_; + delete pool_; + } +} + +void Benchmark::Run(int iters) { RunWithArgs({}, {}, iters); } + +string GetRendezvousKey(const Node* node) { + string send_device; + TF_CHECK_OK(GetNodeAttr(node->def(), "send_device", &send_device)); + string recv_device; + TF_CHECK_OK(GetNodeAttr(node->def(), "recv_device", &recv_device)); + string tensor_name; + TF_CHECK_OK(GetNodeAttr(node->def(), "tensor_name", &tensor_name)); + uint64 send_device_incarnation; + TF_CHECK_OK(GetNodeAttr(node->def(), "send_device_incarnation", + reinterpret_cast(&send_device_incarnation))); + return Rendezvous::CreateKey(send_device, send_device_incarnation, + recv_device, tensor_name, FrameAndIter(0, 0)); +} + +void Benchmark::RunWithArgs( + const std::vector>& inputs, + const std::vector& outputs, int iters) { + if (device_) { + // Gets inputs' and outputs' rendezvous keys. + std::vector> in; + for (const auto& p : inputs) { + in.push_back({GetRendezvousKey(p.first), p.second}); + } + std::vector out; + for (const auto& n : outputs) { + out.push_back(GetRendezvousKey(n)); + } + Tensor unused; // In benchmark, we don't care the return value. 
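+    // (Editorial note) Each iteration below is one full cycle through the
+    // local rendezvous, roughly:
+    //   rendez_->Send(in_key, args, tensor, false);       // feed inputs
+    //   exec_->Run(args);                                  // run the graph
+    //   rendez_->Recv(out_key, args, &unused, &is_dead);   // drain outputs
+    // Three untimed warm-up cycles run first, then the timed loop.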
+ bool is_dead; + + // Warm up + Executor::Args args; + args.rendezvous = rendez_; + args.runner = [this](std::function closure) { + pool_->Schedule(closure); + }; + for (int i = 0; i < 3; ++i) { + for (const auto& p : in) { + rendez_->Send(p.first, Rendezvous::Args(), p.second, false); + } + TF_CHECK_OK(exec_->Run(args)); + for (const string& key : out) { + rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead); + } + } + TF_CHECK_OK(device_->Sync()); + + testing::StartTiming(); + while (iters-- > 0) { + for (const auto& p : in) { + rendez_->Send(p.first, Rendezvous::Args(), p.second, false); + } + TF_CHECK_OK(exec_->Run(args)); + for (const string& key : out) { + rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead); + } + } + + TF_CHECK_OK(device_->Sync()); + testing::StopTiming(); + } +} + +} // end namespace test +} // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h new file mode 100644 index 0000000000..5ebe13e1d4 --- /dev/null +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h @@ -0,0 +1,52 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ +#define TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ + +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +class Device; +class SessionOptions; + +namespace test { + +class Benchmark { + public: + // "device" must be either "cpu" or "gpu". Takes ownership of "g" + // and "init". + Benchmark(const string& device, Graph* g, + const SessionOptions* options = nullptr, Graph* init = nullptr); + ~Benchmark(); + + // Executes the graph for "iters" times. + void Run(int iters); + + // If "g" contains send/recv nodes, before each execution, we send + // inputs to the corresponding recv nodes in the graph, after each + // execution, we recv outputs from the corresponding send nodes in + // the graph. In the benchmark, we throw away values returned by the + // graph. 
+ void RunWithArgs(const std::vector>& inputs, + const std::vector& outputs, int iters); + + private: + thread::ThreadPool* pool_ = nullptr; + thread::ThreadPool* non_blocking_pool_ = nullptr; + Device* device_ = nullptr; + Rendezvous* rendez_ = nullptr; + Executor* exec_ = nullptr; + + TF_DISALLOW_COPY_AND_ASSIGN(Benchmark); +}; + +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc new file mode 100644 index 0000000000..6a75346805 --- /dev/null +++ b/tensorflow/core/common_runtime/local_device.cc @@ -0,0 +1,51 @@ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/session_options.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace { + +DeviceBase::CpuWorkerThreads eigen_worker_threads; +Eigen::ThreadPoolInterface* eigen_thread_pool = nullptr; +Eigen::ThreadPoolDevice* eigen_device = nullptr; + +static bool InitModule(const SessionOptions& options) { + int32 intra_op_parallelism_threads = + options.config.intra_op_parallelism_threads(); + if (intra_op_parallelism_threads == 0) { + intra_op_parallelism_threads = port::NumSchedulableCPUs(); + } + LOG(INFO) << "Local device intra op parallelism threads: " + << intra_op_parallelism_threads; + eigen_worker_threads.num_threads = intra_op_parallelism_threads; + eigen_worker_threads.workers = new thread::ThreadPool( + options.env, "Eigen", intra_op_parallelism_threads); + eigen_thread_pool = new EigenThreadPoolWrapper(eigen_worker_threads.workers); + eigen_device = new Eigen::ThreadPoolDevice(eigen_thread_pool, + eigen_worker_threads.num_threads); + return true; +} +} // end namespace + +// LocalDevice ---------------------------------------------------------------- + +LocalDevice::LocalDevice(const SessionOptions& options, + const DeviceAttributes& attributes, + Allocator* device_allocator) + : Device(options.env, attributes, device_allocator) { + // All ThreadPoolDevices in the process will use this single fixed + // sized threadpool for numerical computations. + static bool init = InitModule(options); + CHECK(init); // Avoids compiler warning that init is unused. + set_tensorflow_cpu_worker_threads(&eigen_worker_threads); + set_eigen_cpu_device(eigen_device); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h new file mode 100644 index 0000000000..fc4cfc2dfc --- /dev/null +++ b/tensorflow/core/common_runtime/local_device.h @@ -0,0 +1,27 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_ +#define TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" + +namespace tensorflow { + +class SessionOptions; + +// This class is shared by ThreadPoolDevice and GPUDevice and +// initializes a shared Eigen compute device used by both. This +// should eventually be removed once we refactor ThreadPoolDevice and +// GPUDevice into more 'process-wide' abstractions. 
+class LocalDevice : public Device { + public: + LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes, + Allocator* device_allocator); + ~LocalDevice() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(LocalDevice); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/local_session.cc b/tensorflow/core/common_runtime/local_session.cc new file mode 100644 index 0000000000..ab6993b8a2 --- /dev/null +++ b/tensorflow/core/common_runtime/local_session.cc @@ -0,0 +1,500 @@ +#include "tensorflow/core/common_runtime/local_session.h" + +#include +#include + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/common_runtime/session_factory.h" +#include "tensorflow/core/common_runtime/simple_placer.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_partition.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +namespace { + +thread::ThreadPool* kernel_thread_pool_ = nullptr; +static bool InitModule(const SessionOptions& options) { + int32 inter_op_parallelism_threads = + options.config.inter_op_parallelism_threads(); + if (inter_op_parallelism_threads == 0) { + // Default to using the number of cores available in the process. + inter_op_parallelism_threads = port::NumSchedulableCPUs(); + } + LOG(INFO) << "Local session inter op parallelism threads: " + << inter_op_parallelism_threads; + kernel_thread_pool_ = new thread::ThreadPool(options.env, "Compute", + inter_op_parallelism_threads); + return true; +} + +// TODO(vrv): Figure out how to unify the many different functions +// that generate RendezvousKey, since many of them have to be +// consistent with each other. +string GetRendezvousKey(const string& tensor_name, + const DeviceAttributes& device_info, + const FrameAndIter& frame_iter) { + return strings::StrCat(device_info.name(), ";", + strings::FpToString(device_info.incarnation()), ";", + device_info.name(), ";", tensor_name, ";", + frame_iter.frame_id, ":", frame_iter.iter_id); +} + +// NOTE: On Android with a single device, there is never +// a risk of an OpKernel blocking indefinitely: +// +// 1) No operations do I/O that depends on other simultaneous kernels, +// +// 2) Recv nodes always complete immediately: The inputs are sent into +// the local rendezvous before we start the executor, so the +// corresonding recvs will not block. 
+// +// Based on these assumptions, we can use the same thread pool for +// both "non-blocking" and "blocking" OpKernels on Android. +// +// This may change down the road when we add support for multiple +// devices that run concurrently, in which case we will need to +// revisit this decision. +void SchedClosure(std::function c) { +// TODO(sanjay): Get rid of __ANDROID__ path +#ifdef __ANDROID__ + // On Android, there is no implementation of ThreadPool that takes + // std::function, only Closure, which we cannot easily convert. + // + // Instead, we just run the function in-line, which is currently + // safe given the reasoning above. + c(); +#else + kernel_thread_pool_->Schedule(c); +#endif // __ANDROID__ +} + +} // namespace + +LocalSession::LocalSession(const SessionOptions& options, + const DeviceMgr* device_mgr) + : options_(options), + device_mgr_(device_mgr), + cancellation_manager_(new CancellationManager()) { + static bool init = InitModule(options); + CHECK(init); // Avoids compiler warning that init is unused. + session_handle_ = strings::FpToString(random::New64()); + int devices_added = 0; + if (options.config.log_device_placement()) { + const string mapping_str = device_mgr_->DeviceMappingString(); + printf("Device mapping:\n%s", mapping_str.c_str()); + LOG(INFO) << "Device mapping:\n" << mapping_str; + } + for (auto d : device_mgr_->ListDevices()) { + devices_.push_back(d); + device_set_.AddDevice(d); + d->op_segment()->AddHold(session_handle_); + + // The first device added is special: it is the 'client device' (a + // CPU device) from which we feed and fetch Tensors. + if (devices_added == 0) { + device_set_.set_client_device(d); + } + ++devices_added; + } +} + +LocalSession::~LocalSession() { + for (auto d : device_mgr_->ListDevices()) { + d->op_segment()->RemoveHold(session_handle_); + } + for (auto it : executors_) { + delete it.second; + } + delete cancellation_manager_; +} + +Status LocalSession::Create(const GraphDef& graph) { + mutex_lock l(graph_def_lock_); + if (graph_created_) { + return errors::AlreadyExists( + "A Graph has already been created for this session."); + } + return ExtendLocked(graph); +} + +Status LocalSession::Extend(const GraphDef& graph) { + mutex_lock l(graph_def_lock_); + return ExtendLocked(graph); +} + +Status LocalSession::ExtendLocked(const GraphDef& graph) { + graph_created_ = true; // In case this is first call + graph_def_.MergeFrom(graph); + return Status::OK(); +} + +Status LocalSession::Run(const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs) { + { + mutex_lock l(graph_def_lock_); + if (!graph_created_) { + return errors::InvalidArgument( + "Session was not created with a graph before Run()!"); + } + } + + // Extract the inputs names for this run of the session. + std::vector input_tensor_names; + input_tensor_names.reserve(inputs.size()); + for (const auto& it : inputs) { + input_tensor_names.push_back(it.first); + } + + // Check if we already have an executor for these arguments. + ExecutorsAndKeys* executors_and_keys; + Status s = GetOrCreateExecutors(input_tensor_names, output_names, + target_nodes, &executors_and_keys); + if (!s.ok()) { + return s; + } + + IntraProcessRendezvous* rendez = + new IntraProcessRendezvous(device_mgr_.get()); + core::ScopedUnref rendez_unref(rendez); + + // Insert the input tensors into the local rendezvous by their + // rendezvous key. 
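+  // (Editorial note) The keys come from GetRendezvousKey() above; feeding
+  // tensor "x:0" on the client CPU device produces something like
+  //   "/job:localhost/replica:0/task:0/cpu:0;<incarnation>;"
+  //   "/job:localhost/replica:0/task:0/cpu:0;x:0;0:0"
+  // where <incarnation> is the device incarnation fingerprint rendered by
+  // strings::FpToString().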
+ for (const auto& input : inputs) { + const string& input_key = executors_and_keys->input_keys[input.first]; + s = rendez->Send(input_key, Rendezvous::Args(), input.second, false); + if (!s.ok()) { + rendez->StartAbort(s); + return s; + } + } + + // Start parallel Executors. + Notification executors_done; + const int num_executors = executors_and_keys->device_executors.size(); + ExecutorBarrier* barrier = new ExecutorBarrier( + num_executors, rendez, [&executors_done, &s](const Status& ret) { + s = ret; + executors_done.Notify(); + }); + + Executor::Args args; + args.rendezvous = rendez; + args.cancellation_manager = cancellation_manager_; + args.runner = SchedClosure; + + for (auto device_executor : executors_and_keys->device_executors) { + Executor* exec = device_executor.second; + exec->RunAsync(args, barrier->Get()); + } + + executors_done.WaitForNotification(); + + TF_RETURN_IF_ERROR(s); + + if (!output_names.empty()) { + outputs->resize(output_names.size()); + } + + // Get the outputs from the rendezvous + for (size_t output_offset = 0; output_offset < output_names.size(); + ++output_offset) { + const string& output_key = + executors_and_keys->output_keys[output_names[output_offset]]; + Tensor output_tensor; + bool is_dead; + + // Fetch data from the Rendezvous. + s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead); + if (is_dead) { + s = errors::InvalidArgument("The tensor returned for ", + output_names[output_offset], + " was not valid."); + } + if (!s.ok()) { + rendez->StartAbort(s); + outputs->clear(); + return s; + } + + (*outputs)[output_offset] = output_tensor; + } + + return s; +} + +Status LocalSession::GetOrCreateExecutors( + gtl::ArraySlice inputs, gtl::ArraySlice outputs, + gtl::ArraySlice target_nodes, + ExecutorsAndKeys** executors_and_keys) { + // Sort the inputs and outputs, so we don't create separate + // executors when a user passes in the same inputs/outputs in + // different orders. + // + // We could consider some other signature instead of sorting that + // preserves the same property to avoid the sort in the future. + std::vector inputs_sorted(inputs.begin(), inputs.end()); + std::vector outputs_sorted(outputs.begin(), outputs.end()); + std::vector tn_sorted(target_nodes.begin(), target_nodes.end()); + std::sort(inputs_sorted.begin(), inputs_sorted.end()); + std::sort(outputs_sorted.begin(), outputs_sorted.end()); + std::sort(tn_sorted.begin(), tn_sorted.end()); + + const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->", + str_util::Join(outputs_sorted, ","), "/", + str_util::Join(tn_sorted, ",")); + + // See if we already have the executors for this run. + { + mutex_lock l(executor_lock_); // could use reader lock + auto it = executors_.find(key); + if (it != executors_.end()) { + *executors_and_keys = it->second; + return Status::OK(); + } + } + + // The executor_lock_ is intentionally released while executor is + // being created. 
+ std::unordered_map graphs; + Status s = CreateGraphs(inputs, outputs, target_nodes, &graphs); + if (!s.ok()) { + return s; + } + + bool has_control_flow = false; + for (const auto& graph : graphs) { + for (const Node* n : graph.second->nodes()) { + if (IsControlFlow(n)) { + has_control_flow = true; + break; + } + } + if (has_control_flow) break; + } + + std::unique_ptr ek(new ExecutorsAndKeys); + + for (const auto& graph : graphs) { + const string& partition_name = graph.first; + Graph* partition_graph = graph.second; + + Device* d; + s = device_mgr_->LookupDevice(partition_name, &d); + if (!s.ok()) { + return s; + } + + LocalExecutorParams params; + params.has_control_flow = has_control_flow; + params.device = d; + params.create_kernel = [this, d](const NodeDef& ndef, OpKernel** kernel) { + return CreateCachedKernel(d, session_handle_, nullptr, ndef, kernel); + }; + params.delete_kernel = [this, d](OpKernel* kernel) { + DeleteCachedKernel(d, session_handle_, kernel); + }; + + Executor* tmp_exec; + s = NewLocalExecutor(params, partition_graph, &tmp_exec); + if (!s.ok()) { + return s; + } + ek->device_executors.insert(std::make_pair(graph.first, tmp_exec)); + } + + // Compute the rendezvous keys to avoid recomputing them every time. + // + // We always use the first device as the device name portion of the + // key, even if we're feeding another graph. + for (const string& input : inputs) { + ek->input_keys[input] = GetRendezvousKey( + input, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); + } + for (const string& output : outputs) { + ek->output_keys[output] = GetRendezvousKey( + output, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); + } + + // Reacquire the lock, try to insert into the map. + mutex_lock l(executor_lock_); + const bool inserted = executors_.insert(std::make_pair(key, ek.get())).second; + if (!inserted) { + // Another thread created the entry before us, so delete the + // one we created and return the already created one. + auto it = executors_.find(key); + *executors_and_keys = it->second; + } else { + *executors_and_keys = ek.release(); + } + + return Status::OK(); +} + +void LocalSession::SaveStatefulNodes(Graph* graph) { + for (Node* n : graph->nodes()) { + if (n->op_def().is_stateful()) { + VLOG(2) << "Saving " << n->DebugString(); + stateful_placements_[n->name()] = n->assigned_device_name(); + } + } +} + +void LocalSession::RestoreStatefulNodes(Graph* graph) { + for (Node* n : graph->nodes()) { + if (n->op_def().is_stateful()) { + auto iter = stateful_placements_.find(n->name()); + if (iter != stateful_placements_.end()) { + n->set_assigned_device_name(iter->second); + VLOG(2) << "Restored " << n->DebugString(); + } + } + } +} + +Status LocalSession::CreateGraphs(gtl::ArraySlice feeds, + gtl::ArraySlice fetches, + gtl::ArraySlice target_nodes, + std::unordered_map* outputs) { + Graph graph(OpRegistry::Global()); + GraphConstructorOptions opts; + + { + mutex_lock l(graph_def_lock_); + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, &graph)); + } + + TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( + &graph, feeds, fetches, target_nodes, + device_set_.client_device()->attributes())); + + // Run the simple placer after rewriting the graph. + std::unordered_map node_name_to_cost_map; + for (Node* n : graph.nodes()) { + node_name_to_cost_map[n->name()] = n->cost_id(); + } + SimplePlacer placer(&graph, &device_set_, &node_name_to_cost_map, &options_); + + { + mutex_lock l(mu_); + // Restore stateful nodes. 
+ RestoreStatefulNodes(&graph); + TF_RETURN_IF_ERROR(placer.Run()); + // Save stateful nodes. + SaveStatefulNodes(&graph); + } + + // Partition the graph across devices. + std::unordered_map partitions; + PartitionOptions popts; + popts.node_to_loc = [](const Node* node) { + return node->assigned_device_name(); + }; + popts.new_name = [this](const string& prefix) { + mutex_lock l(mu_); + return strings::StrCat(prefix, "/_", name_counter_++); + }; + popts.get_incarnation = [](const string& name) { + // The local session does not have changing incarnation numbers. + // Just return '1'. + return 1; + }; + popts.control_flow_added = false; + TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions)); + + std::vector device_names; + for (auto device : devices_) { + // Extract the LocalName from the device. + device_names.push_back(DeviceNameUtils::LocalName(device->name())); + } + + // Check for valid partitions. + for (const auto& partition : partitions) { + const string& local_partition_name = + DeviceNameUtils::LocalName(partition.first); + if (std::count(device_names.begin(), device_names.end(), + local_partition_name) == 0) { + return errors::InvalidArgument( + "Creating a partition for ", local_partition_name, + " which doesn't exist in the list of available devices. Available " + "devices: ", + str_util::Join(device_names, ",")); + } + } + + for (const auto& partition : partitions) { + const string& partition_name = partition.first; + + const GraphDef& graph_def = partition.second; + VLOG(2) << "Created " << graph_def.DebugString() << " for " + << partition_name; + + Graph* device_graph = new Graph(OpRegistry::Global()); + GraphConstructorOptions device_opts; + // There are internal operations (e.g., send/recv) that we now + // allow. + device_opts.allow_internal_ops = true; + device_opts.expect_device_spec = true; + Status s = + ConvertGraphDefToGraph(device_opts, graph_def, device_graph); + if (!s.ok()) { + delete device_graph; + // Also delete other graphs created during the loop. 
+ gtl::STLDeleteValues(outputs); + return s; + } + outputs->insert(std::make_pair(partition_name, device_graph)); + } + + return Status::OK(); +} + +::tensorflow::Status LocalSession::Close() { + cancellation_manager_->StartCancel(); + return ::tensorflow::Status::OK(); +} + +class LocalSessionFactory : public SessionFactory { + public: + LocalSessionFactory() {} + + Session* NewSession(const SessionOptions& options) override { + std::vector devices; + DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0", + &devices); + return new LocalSession(options, new DeviceMgr(devices)); + } +}; + +class LocalSessionRegistrar { + public: + LocalSessionRegistrar() { + SessionFactory::Register("LOCAL_SESSION", new LocalSessionFactory()); + } +}; +static LocalSessionRegistrar registrar; + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/local_session.h b/tensorflow/core/common_runtime/local_session.h new file mode 100644 index 0000000000..453cfdde47 --- /dev/null +++ b/tensorflow/core/common_runtime/local_session.h @@ -0,0 +1,109 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_ +#define TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class Device; + +class LocalSession : public Session { + public: + // Takes ownership of 'device_mgr'. + LocalSession(const SessionOptions& options, const DeviceMgr* device_mgr); + ~LocalSession() override; + + ::tensorflow::Status Create(const GraphDef& graph) override; + ::tensorflow::Status Extend(const GraphDef& graph) override; + ::tensorflow::Status Run(const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs) override; + ::tensorflow::Status Close() override; + + private: + struct ExecutorsAndKeys { + std::unordered_map device_executors; + std::unordered_map input_keys; + std::unordered_map output_keys; + + ~ExecutorsAndKeys() { + for (auto it : device_executors) { + delete it.second; + } + } + }; + + // Retrieves an already existing set of executors to run 'inputs' and + // 'outputs', or creates and caches them for future use. + ::tensorflow::Status GetOrCreateExecutors( + gtl::ArraySlice inputs, gtl::ArraySlice outputs, + gtl::ArraySlice target_nodes, + ExecutorsAndKeys** executors_and_keys); + + // Creates several graphs given the existing graph_def_ and the + // input feeds and fetches, given 'devices'. + ::tensorflow::Status CreateGraphs( + gtl::ArraySlice feeds, gtl::ArraySlice fetches, + gtl::ArraySlice target_nodes, + std::unordered_map* outputs); + + ::tensorflow::Status ExtendLocked(const GraphDef& graph) + EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); + + const SessionOptions options_; + + // Device structures. 
+ const std::unique_ptr device_mgr_; + std::vector devices_; // not owned + DeviceSet device_set_; + + string session_handle_; + bool graph_created_ GUARDED_BY(graph_def_lock_) = false; + + mutex graph_def_lock_; + GraphDef graph_def_ GUARDED_BY(graph_def_lock_); + + mutex executor_lock_; // protects executors_ + // Holds mappings from signature to the executors that process + // it. The reason for a level of indirection around mapped_type is + // to guarantee address stability. + std::unordered_map executors_ + GUARDED_BY(executor_lock_); + + CancellationManager* cancellation_manager_; + + // Saves and restores device placements for stateful nodes. + mutex mu_; + void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); + void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Map of placed stateful nodes, i.e. nodes for which is_stateful() + // is true, such as "params" and "queue" nodes. Once placed these + // nodes can not be moved to a different device. Maps node names to + // device names. + std::unordered_map stateful_placements_ GUARDED_BY(mu_); + + // For generating unique names. + int64 name_counter_ GUARDED_BY(mu_) = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(LocalSession); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_ diff --git a/tensorflow/core/common_runtime/local_session_test.cc b/tensorflow/core/common_runtime/local_session_test.cc new file mode 100644 index 0000000000..9325fe44c3 --- /dev/null +++ b/tensorflow/core/common_runtime/local_session_test.cc @@ -0,0 +1,314 @@ +#include "tensorflow/core/common_runtime/local_session.h" + +#include +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/device_name_utils.h" +#include + +namespace tensorflow { +namespace { + +Session* CreateSession() { + SessionOptions options; + (*options.config.mutable_device_count())["CPU"] = 2; + return NewSession(options); +} + +class LocalSessionMinusAXTest : public ::testing::Test { + public: + void Initialize(std::initializer_list a_values) { + RequireDefaultOps(); + Graph graph(OpRegistry::Global()); + + Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); + test::FillValues(&a_tensor, a_values); + Node* a = test::graph::Constant(&graph, a_tensor); + a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); + + Tensor x_tensor(DT_FLOAT, TensorShape({2, 1})); + test::FillValues(&x_tensor, {1, 1}); + Node* x = test::graph::Constant(&graph, x_tensor); + x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1"); + x_ = x->name(); + + // y = A * x + Node* y = test::graph::Matmul(&graph, a, x, false, false); + y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); + y_ = y->name(); + + Node* y_neg = test::graph::Unary(&graph, "Neg", y); + y_neg_ = y_neg->name(); + y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1"); + + test::graph::ToGraphDef(&graph, &def_); + } + + string x_; + string y_; + 
string y_neg_; + GraphDef def_; +}; + +TEST_F(LocalSessionMinusAXTest, RunSimpleNetwork) { + Initialize({3, 2, -1, 0}); + std::unique_ptr session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def_)); + std::vector> inputs; + + // Request two targets: one fetch output and one non-fetched output. + std::vector output_names = {y_ + ":0"}; + std::vector target_nodes = {y_neg_}; + std::vector outputs; + Status s = session->Run(inputs, output_names, target_nodes, &outputs); + ASSERT_OK(s); + + ASSERT_EQ(1, outputs.size()); + // The first output should be initiailzed and have the correct + // output. + auto mat = outputs[0].matrix(); + ASSERT_TRUE(outputs[0].IsInitialized()); + EXPECT_FLOAT_EQ(5.0, mat(0, 0)); +} + +TEST_F(LocalSessionMinusAXTest, TestFeed) { + Initialize({1, 2, 3, 4}); + std::unique_ptr session(CreateSession()); + ASSERT_TRUE(session != nullptr); + + ASSERT_OK(session->Create(def_)); + + // Fill in the input and ask for the output + // + // Note that the input being fed is on the second device. + Tensor t(DT_FLOAT, TensorShape({2, 1})); + t.matrix()(0, 0) = 5; + t.matrix()(1, 0) = 6; + std::vector> inputs = {{x_, t}}; + std::vector output_names = {y_ + ":0"}; + std::vector outputs; + + // Run the graph + Status s = session->Run(inputs, output_names, {}, &outputs); + ASSERT_OK(s); + + ASSERT_EQ(1, outputs.size()); + auto mat = outputs[0].matrix(); + + // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6 + EXPECT_FLOAT_EQ(17.0, mat(0, 0)); + EXPECT_FLOAT_EQ(39.0, mat(1, 0)); +} + +TEST_F(LocalSessionMinusAXTest, TestConcurrency) { + Initialize({1, 2, 3, 4}); + std::unique_ptr session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def_)); + + // Fill in the input and ask for the output + thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4); + + // Run the graph 1000 times in 4 different threads concurrently. + std::vector output_names = {y_ + ":0"}; + auto fn = [&session, output_names]() { + for (int i = 0; i < 1000; ++i) { + std::vector> inputs; + std::vector outputs; + // Run the graph + Status s = session->Run(inputs, output_names, {}, &outputs); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(1, outputs.size()); + auto mat = outputs[0].matrix(); + EXPECT_FLOAT_EQ(3.0, mat(0, 0)); + } + }; + + for (int i = 0; i < 4; ++i) { + tp->Schedule(fn); + } + + // Wait for the functions to finish. + delete tp; +} + +TEST_F(LocalSessionMinusAXTest, TwoCreateCallsFails) { + Initialize({1, 2, 3, 4}); + std::unique_ptr session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def_)); + + // Second is not. + ASSERT_FALSE(session->Create(def_).ok()); +} + +TEST_F(LocalSessionMinusAXTest, ForgetToCreate) { + Initialize({1, 2, 3, 4}); + std::unique_ptr session(CreateSession()); + ASSERT_TRUE(session != nullptr); + std::vector> inputs; + std::vector outputs; + ASSERT_FALSE(session->Run(inputs, {y_ + ":0"}, {y_neg_}, &outputs).ok()); +} + +TEST_F(LocalSessionMinusAXTest, InvalidDevice) { + GraphDef def; + Graph graph(OpRegistry::Global()); + + Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); + a_tensor.flat().setRandom(); + Node* a = test::graph::Constant(&graph, a_tensor); + a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); + Tensor x_tensor(DT_FLOAT, TensorShape({2, 1})); + x_tensor.flat().setRandom(); + Node* x = test::graph::Constant(&graph, x_tensor); + x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1"); + // Skip placing y. 
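+ // (cpu:2 is outside the two-CPU device set built by CreateSession(), so
+ // the Run() call below should fail when the partition for cpu:2 is
+ // checked against the available devices.)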
+ Node* y = test::graph::Matmul(&graph, a, x, false, false); + y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2"); + + test::graph::ToGraphDef(&graph, &def); + + std::unique_ptr session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def)); + std::vector> inputs; + std::vector output_names = {y->name() + ":0"}; + std::vector outputs; + + // Should return an error. + ASSERT_FALSE(session->Run(inputs, output_names, {}, &outputs).ok()); + + // Fix placement and run again + def.Clear(); + y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1"); + test::graph::ToGraphDef(&graph, &def); + session.reset(CreateSession()); + ASSERT_OK(session->Create(def)); + ASSERT_OK(session->Run(inputs, output_names, {}, &outputs)); +} + +TEST(LocalSessionTest, KeepsStateAcrossRunsOfSession) { + GraphDef def; + Graph g(OpRegistry::Global()); + Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10})); + var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); + + Tensor twenty(DT_FLOAT, TensorShape({10})); + for (int i = 0; i < 10; ++i) { + twenty.flat()(i) = 20.0; + } + + Node* twenty_node = test::graph::Constant(&g, twenty); + twenty_node->set_assigned_device_name( + "/job:localhost/replica:0/task:0/cpu:0"); + + Node* init = test::graph::Assign(&g, var, twenty_node); + init->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); + + test::graph::ToGraphDef(&g, &def); + + std::unique_ptr session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def)); + + std::vector> inputs; + std::vector outputs; + + // Initialize the variable + Status s = session->Run(inputs, {init->name()}, {}, &outputs); + ASSERT_OK(s); + + // Get the variable's data + s = session->Run(inputs, {var->name() + ":0"}, {}, &outputs); + ASSERT_OK(s); + ASSERT_EQ(1, outputs.size()); + ASSERT_TRUE(outputs[0].IsInitialized()); + EXPECT_EQ(20.0, outputs[0].flat()(0)); +} + +TEST(LocalSessionTest, MultipleFeedTest) { + GraphDef def; + Graph g(OpRegistry::Global()); + Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10})); + var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); + + Tensor first_value(DT_FLOAT, TensorShape({})); + first_value.scalar()() = 1.0; + Node* first_const = test::graph::Constant(&g, first_value); + Node* first_identity = test::graph::Identity(&g, first_const); + + Tensor second_value(DT_FLOAT, TensorShape({})); + second_value.scalar()() = 2.0; + Node* second_const = test::graph::Constant(&g, second_value); + Node* second_identity = test::graph::Identity(&g, second_const); + + test::graph::ToGraphDef(&g, &def); + + std::unique_ptr session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def)); + + std::vector outputs; + + // Fetch without feeding. 
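+ // (first_identity and second_identity simply forward their Constant
+ // inputs, so with no feeds the two fetches should come back as 1.0 and
+ // 2.0 respectively.)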
+ Status s = session->Run( + {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, + &outputs); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(1.0, outputs[0].flat()(0)); + ASSERT_EQ(2.0, outputs[1].flat()(0)); + + s = session->Run( + {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {}, + &outputs); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(2.0, outputs[0].flat()(0)); + ASSERT_EQ(1.0, outputs[1].flat()(0)); + + Tensor value_11(DT_FLOAT, TensorShape({})); + value_11.scalar()() = 11.0; + Tensor value_22(DT_FLOAT, TensorShape({})); + value_22.scalar()() = 22.0; + + // Feed [first_const, second_const] + s = session->Run( + {{first_const->name(), value_11}, {second_const->name(), value_22}}, + {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, + &outputs); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(11.0, outputs[0].flat()(0)); + ASSERT_EQ(22.0, outputs[1].flat()(0)); + + // Feed [second_const, first_const] + s = session->Run( + {{second_const->name(), value_22}, {first_const->name(), value_11}}, + {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, + &outputs); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(11.0, outputs[0].flat()(0)); + ASSERT_EQ(22.0, outputs[1].flat()(0)); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc new file mode 100644 index 0000000000..111dea6d4c --- /dev/null +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -0,0 +1,170 @@ +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#if (!defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)) && \ + (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA) +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" +#endif +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +namespace { + +void CopyTensorBetweenDevices(const string& id, DeviceContext* send_dev_context, + DeviceContext* recv_dev_context, Device* src, + Device* dst, + const AllocatorAttributes src_alloc_attr, + const AllocatorAttributes dst_alloc_attr, + const Tensor* input, Tensor* output, + std::function done) { + if (src->attributes().device_type() != dst->attributes().device_type()) { + done(errors::Unimplemented( + "Copy between device types not yet implemented: src=", src->name(), + " dst=", dst->name())); + } else if (src->attributes().device_type() != "CPU") { + done(errors::Unimplemented( + "Copy between non-CPU devices not yet implemented")); + } + *output = *input; + done(Status::OK()); +} + +#if (!defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)) && \ + (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA) +constexpr auto CopyTensorBetweenDevicesFunc = &GPUUtil::CopyViaDMA; +#else +constexpr auto CopyTensorBetweenDevicesFunc = &CopyTensorBetweenDevices; +#endif + +} // end namespace + +IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr) + : device_mgr_(device_mgr), local_(NewLocalRendezvous()) {} + 
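For orientation, a minimal sketch of how executor code might exercise this class; the DeviceMgr, key string, and tensor here are assumed to be supplied by the caller (keys are normally produced by the graph partitioner), so this is illustrative rather than part of the change:

// Illustrative sketch: send a tensor through an IntraProcessRendezvous and
// receive it again on the same worker.
void SketchSendRecv(const DeviceMgr* device_mgr, const string& key,
                    const Tensor& val) {
  IntraProcessRendezvous* rendez = new IntraProcessRendezvous(device_mgr);
  Rendezvous::Args send_args;  // alloc_attrs/device_context set by executors
  Rendezvous::Args recv_args;
  // Producer side: buffers "val" in the wrapped local rendezvous under "key".
  CHECK(rendez->Send(key, send_args, val, false /* is_dead */).ok());
  // Consumer side: the callback fires once the value is available;
  // SameWorkerRecvDone copies it across devices when necessary.
  rendez->RecvAsync(key, recv_args,
                    [](const Status& s, const Rendezvous::Args& send_args,
                       const Rendezvous::Args& recv_args, const Tensor& t,
                       bool is_dead) { CHECK(s.ok()); });
  rendez->Unref();
}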
+IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); } + +Status IntraProcessRendezvous::Send(const string& key, + const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) { + VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key; + { + mutex_lock l(mu_); + if (!status_.ok()) return status_; + } + Rendezvous::ParsedKey parsed; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed)); + + // Buffers "val" and "device_context" in local_. + return local_->Send(key, args, val, is_dead); +} + +Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src, + Rendezvous::ParsedKey* parsed) { + { + mutex_lock l(mu_); + if (!status_.ok()) return status_; + } + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed)); + return Status::OK(); +} + +void IntraProcessRendezvous::SameWorkerRecvDone( + const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out, + StatusCallback done) { + // Do a quick copy (sharing the underlying buffer) if both tensors + // are on host memory. + const bool src_host = + (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU"); + const bool dst_host = + (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU"); + if (src_host && dst_host) { + *out = in; + done(Status::OK()); + return; + } + + // This copy must involve a non-CPU device. Hence, "in" must support DMA + // (e.g., string tensors do not work on GPU). + if (!DataTypeCanUseMemcpy(in.dtype())) { + done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()), + " tensor may not be copied from/to a GPU.")); + return; + } + + Device* src_device; + Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device); + if (!s.ok()) { + done(s); + return; + } + Device* dst_device; + s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device); + if (!s.ok()) { + done(s); + return; + } + + AllocatorAttributes attr = recv_args.alloc_attrs; + attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() || + recv_args.alloc_attrs.gpu_compatible()); + Allocator* out_allocator = dst_device->GetAllocator(attr); + Tensor copy(out_allocator, in.dtype(), in.shape()); + *out = copy; + + CopyTensorBetweenDevicesFunc(parsed.edge_name, send_args.device_context, + recv_args.device_context, src_device, dst_device, + send_args.alloc_attrs, recv_args.alloc_attrs, + &in, out, done); +} + +void IntraProcessRendezvous::RecvAsync(const string& key, + const Rendezvous::Args& recv_args, + DoneCallback done) { + VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key; + + Rendezvous::ParsedKey parsed; + Status s = ParseKey(key, false /*!is_src*/, &parsed); + if (!s.ok()) { + done(s, Args(), recv_args, Tensor(), false); + return; + } + + // Recv the tensor from local_. 
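+ // The callback below fires once the matching Send() has made the value
+ // available; on success, SameWorkerRecvDone then copies it to the
+ // receiving device if needed.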
+ local_->RecvAsync(key, recv_args, [this, parsed, done]( + const Status& status, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& in, bool is_dead) { + Status s = status; + Tensor* out = new Tensor; + StatusCallback final_callback = [done, send_args, recv_args, out, + is_dead](const Status& s) { + done(s, send_args, recv_args, *out, is_dead); + delete out; + }; + + if (s.ok()) { + SameWorkerRecvDone(parsed, send_args, recv_args, in, out, final_callback); + } else { + final_callback(s); + } + }); +} + +void IntraProcessRendezvous::StartAbort(const Status& s) { + CHECK(!s.ok()); + local_->StartAbort(s); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h new file mode 100644 index 0000000000..eaae65f956 --- /dev/null +++ b/tensorflow/core/common_runtime/rendezvous_mgr.h @@ -0,0 +1,73 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_ +#define TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +// IntraProcessRendezvous is a Rendezvous which expects all producers +// and consumers to be devices immediately accessible within the +// process. That is, it will never be necessary to perform an RPC to +// communicate with either. +// +// Buffering of Tensor values is delegated to a "local" Rendezvous +// obtained from NewLocalRendezvous(). This class just adds +// functionality to coordinate multiple process-local devices. +class IntraProcessRendezvous : public Rendezvous { + public: + explicit IntraProcessRendezvous(const DeviceMgr* device_mgr); + + // Forwards to local_, where the Tensor "val" will be buffered and + // any waiting callback stored. + Status Send(const string& key, const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) override; + + // This method is called only by the RecvOp. It tests to see + // whether the value will be produced by a local or remote device + // and handles accordingly. In the local case it forwards to + // local_, in the remote case it initiates an RPC request. + void RecvAsync(const string& key, const Rendezvous::Args& args, + DoneCallback done) override; + + void StartAbort(const Status& status) override; + + private: + const DeviceMgr* device_mgr_; + Rendezvous* local_; // Owns a Ref on this object. + + mutable mutex mu_; + + // Status given by StartAbort() if any. + Status status_ GUARDED_BY(mu_); + + ~IntraProcessRendezvous() override; + + // Parses "key" into "parsed". If "is_src" is true, checks that the + // rendezvous key's source is in this process. If "is_src" is false, + // checks that the rendezvous key's destination is in this process. + Status ParseKey(const string& key, bool is_src, + Rendezvous::ParsedKey* parsed); + + // Callback handling the case when a rendezvous has been + // accomplished in local_ and the consumer is local to this process. + // Tensor "in" will be copied into "out". The key "parsed" encodes + // the src and dst devices. 
+ typedef std::function StatusCallback; + void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, + Tensor* out, StatusCallback done); + + TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_ diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc new file mode 100644 index 0000000000..6d1ab5cea4 --- /dev/null +++ b/tensorflow/core/common_runtime/session.cc @@ -0,0 +1,51 @@ +#include + +#include "tensorflow/core/common_runtime/session_factory.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { + +namespace { +Status GetFactory(const SessionOptions& options, SessionFactory** ret) { + string runtime_type = "LOCAL_SESSION"; + if (!options.target.empty()) { + // Use the service based session. + runtime_type = "REMOTE_SESSION"; + } + *ret = SessionFactory::GetFactory(runtime_type); + if (!*ret) { + return errors::NotFound("Could not find session factory for ", + runtime_type); + } + return Status::OK(); +} +} // end namespace + +Session* NewSession(const SessionOptions& options) { + SessionFactory* factory; + Status s = GetFactory(options, &factory); + if (!s.ok()) { + LOG(ERROR) << s; + return nullptr; + } + return factory->NewSession(options); +} + +Status NewSession(const SessionOptions& options, Session** out_session) { + SessionFactory* factory; + Status s = GetFactory(options, &factory); + if (!s.ok()) { + *out_session = nullptr; + LOG(ERROR) << s; + return s; + } + *out_session = factory->NewSession(options); + if (!*out_session) { + return errors::Internal("Failed to create session."); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc new file mode 100644 index 0000000000..666b99812d --- /dev/null +++ b/tensorflow/core/common_runtime/session_factory.cc @@ -0,0 +1,41 @@ +#include "tensorflow/core/common_runtime/session_factory.h" + +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +namespace tensorflow { +namespace { + +static mutex* get_session_factory_lock() { + static mutex session_factory_lock; + return &session_factory_lock; +} + +typedef std::unordered_map SessionFactories; +SessionFactories* session_factories() { + static SessionFactories* factories = new SessionFactories; + return factories; +} + +} // namespace + +void SessionFactory::Register(const string& runtime_type, + SessionFactory* factory) { + mutex_lock l(*get_session_factory_lock()); + if (!session_factories()->insert({runtime_type, factory}).second) { + LOG(ERROR) << "Two session factories are being registered " + << "under" << runtime_type; + } +} + +SessionFactory* SessionFactory::GetFactory(const string& runtime_type) { + mutex_lock l(*get_session_factory_lock()); // could use reader lock + auto it = session_factories()->find(runtime_type); + if (it == session_factories()->end()) { + return nullptr; + } + return it->second; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h new file mode 100644 index 0000000000..f770ba93ff --- /dev/null +++ b/tensorflow/core/common_runtime/session_factory.h @@ -0,0 +1,25 @@ 
+#ifndef TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_ +#define TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_ + +#include + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class Session; +class SessionOptions; + +class SessionFactory { + public: + virtual Session* NewSession(const SessionOptions& options) = 0; + virtual ~SessionFactory() {} + static void Register(const string& runtime_type, SessionFactory* factory); + static SessionFactory* GetFactory(const string& runtime_type); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_ diff --git a/tensorflow/core/common_runtime/session_options.cc b/tensorflow/core/common_runtime/session_options.cc new file mode 100644 index 0000000000..ef585efb5c --- /dev/null +++ b/tensorflow/core/common_runtime/session_options.cc @@ -0,0 +1,9 @@ +#include "tensorflow/core/public/session_options.h" + +#include "tensorflow/core/public/env.h" + +namespace tensorflow { + +SessionOptions::SessionOptions() : env(Env::Default()) {} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc new file mode 100644 index 0000000000..82b5d7ffb0 --- /dev/null +++ b/tensorflow/core/common_runtime/session_test.cc @@ -0,0 +1,17 @@ +#include "tensorflow/core/public/session.h" + +#include "tensorflow/core/public/session_options.h" +#include + +namespace tensorflow { +namespace { + +TEST(SessionTest, InvalidTargetReturnsNull) { + SessionOptions options; + options.target = "invalid target"; + + EXPECT_EQ(nullptr, tensorflow::NewSession(options)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc new file mode 100644 index 0000000000..1cd1db29db --- /dev/null +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -0,0 +1,559 @@ +#include "tensorflow/core/common_runtime/simple_placer.h" + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +namespace { + +// Returns a list of devices sorted by name from 'devices' whose type is in +// 'supported_device_types'. This function searches in order of the device +// types in 'supported_device_types' and returns the *first* subset of devices +// that match. +// +// For example, if suported_device_types contains {GPU, CPU} and +// 'devices' contains CPU and GPU devices, the returned vector will +// include *only* GPU devices, since that is higher in the priority +// order in 'supported_device_types'. 
+std::vector FilterSupportedDevices( + const std::vector& devices, + const DeviceTypeVector& supported_device_types) { + std::vector filtered_devices; + auto device_sort = [](const Device* a, const Device* b) { + return a->name() < b->name(); + }; + for (DeviceType d : supported_device_types) { + for (Device* device : devices) { + if (DeviceType(device->attributes().device_type()) == d) { + filtered_devices.emplace_back(device); + } + } + + // If there are any devices under this device type, return this + // subset. + if (!filtered_devices.empty()) { + std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort); + return filtered_devices; + } + } + + std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort); + return filtered_devices; +} + +bool HasColocatedNodeName(const Node& node) { + return StringPiece(node.def().device()).starts_with("@"); +} + +Status ParseColocatedNodeName(const Node& node, + string* out_colocated_node_name) { + StringPiece device(node.def().device()); + if (!device.Consume("@")) { + return errors::InvalidArgument("Malformed colocated node name: '", device, + "'"); + } + // TODO(mrry): Validate that the node name is a valid node name. + *out_colocated_node_name = device.ToString(); + return Status::OK(); +} + +// This class maintains the connected components of a colocation +// constraint graph, and uses this information to assign a satisfying +// device placement to the nodes of the graph. +// +// The typical usage pattern is: +// +// Graph graph = ...; +// DeviceSet device_set = ...; +// ColocationGraph colocation_graph(graph, device_set); +// +// // Add all the nodes of graph to colocation_graph. +// for (Node* node : graph.nodes()) { +// TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node)); +// } +// +// // Add one or more colocation constraint. +// Node node_1 = *graph.FindNodeId(...); +// Node node_2 = *graph.FindNodeId(...); +// TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2)); +// +// // Assign devices based on the accumulated constraints. +// for (Node* node : graph.nodes()) { +// TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node)); +// } +// +// The implementation uses the union-find algorithm to maintain the +// connected components efficiently and incrementally as edges +// (implied by ColocationGraph::ColocateNodes() invocations) are added. +class ColocationGraph { + public: + ColocationGraph(Graph* graph, const DeviceSet* device_set, + const SessionOptions* options) + : device_set_(device_set), + device_types_(device_set->PrioritizedDeviceTypeList()), + options_(options) { + members_.reserve(graph->num_node_ids()); + } + + // Adds the given node to this ColocationGraph as a singleton. + // + // NOTE: The implementation assumes that the ids of nodes passed to + // this method are dense and zero-based; the memory used will be linear in + // the largest node ID. + // NOTE: If this method returns an error, *this is left in an undefined + // state. + Status AddNode(const Node& node) { + Member member; + TF_RETURN_IF_ERROR(InitializeMember(node, &member)); + CHECK_GE(member.parent, 0); + members_.resize(member.parent + 1); + members_[member.parent] = std::move(member); + return Status::OK(); + } + + // Merge the (possibly disjoint) sets containing nodes "x" and + // "y". Returns OK if the all nodes in the union of these sets can + // be placed on the same device type. + // + // NOTE: If this method returns an error, *this is left in an undefined + // state. 
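+ // For example, if the set containing "x" currently supports {GPU, CPU}
+ // and the set containing "y" supports only {GPU}, the merged set
+ // supports {GPU}; if the intersection were empty, the error below is
+ // returned instead.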
+ Status ColocateNodes(const Node& x, const Node& y) { + int x_root = FindRoot(x.id()); + int y_root = FindRoot(y.id()); + if (x_root != y_root) { + // Merge the sets by swinging the parent pointer of the smaller + // tree to point to the root of the larger tree. Together with + // path compression in ColocationGraph::FindRoot, this ensures + // that we do not experience pathological performance on graphs + // such as chains. + int new_root, old_root; + if (members_[x_root].rank < members_[y_root].rank) { + // The tree rooted at x_root is shallower, so connect it to + // y_root. The rank of y_root is unchanged because its new + // child has strictly less rank. + members_[x_root].parent = y_root; + new_root = y_root; + old_root = x_root; + } else if (members_[x_root].rank > members_[y_root].rank) { + // The tree rooted at y_root is shallower, so connect it to + // x_root. The rank of x_root is unchanged because its new + // child has strictly less rank. + members_[y_root].parent = x_root; + new_root = x_root; + old_root = y_root; + } else { + // Both trees have the same rank, so break the tie by choosing + // x_root as the new root. + members_[y_root].parent = x_root; + // Increment the rank of the tree rooted at x_root, because it + // is now strictly deeper than before. + ++members_[x_root].rank; + new_root = x_root; + old_root = y_root; + } + + // Merge the partial device specifications, and ensure that they are + // compatible. NULL options_ is treated as allowing soft placement. + // TODO(mrry): Consider enriching the error message by pointing + // out which nodes have the explicit partial device + // specifications that caused this conflict. + TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames( + &members_[new_root].device_name, members_[old_root].device_name, + options_ == nullptr || options_->config.allow_soft_placement())); + + // Ensure that the common root has at least one supported device + // type, by computing the intersection of + // members_[new_root].supported_device_types and + // members_[old_root].supported_device_types. + MergeSupportedDevices(&members_[new_root].supported_device_types, + members_[old_root].supported_device_types); + if (members_[x_root].supported_device_types.size() == 0) { + return errors::InvalidArgument( + "Cannot colocate nodes '", x.name(), "' and '", y.name(), + "' because no device type supports both of those nodes and the " + "other nodes colocated with them"); + } + } + return Status::OK(); + } + + // For the given node, subject to the constraints previously given + // to this ColocationGraph, set its assigned_device_name. Returns OK + // if a satisfying device can be found, otherwise an error. + Status AssignDevice(Node* node) { + int node_root = FindRoot(node->id()); + if (members_[node_root].assigned_device == nullptr) { + // We have not yet assigned a device for the colocated node set containing + // n, so we do so now using the constraints on the root node. + + // "devices" will contain the set of feasible placements for the + // colocated node set containing n. + std::vector devices; + if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) { + // The root node has a (possibly partial) device + // specification, so enumerate the physical devices that + // conform to it. + device_set_->FindMatchingDevices(members_[node_root].device_name, + &devices); + + if (!devices.empty()) { + // Filter devices into those that are compatible with the root + // node (and its children). 
+ devices = FilterSupportedDevices( + devices, members_[node_root].supported_device_types); + } + + // Perform soft placement if allow_soft_placement is set. options_ + // being NULL is treated as allowing soft placement. + if (devices.empty() && + (options_ == nullptr || options_->config.allow_soft_placement())) { + // The soft_device_name is the same as the node's device name + // without specifying the device type or ID. + DeviceNameUtils::ParsedName soft_device_name = + members_[node_root].device_name; + soft_device_name.type.clear(); + soft_device_name.has_type = false; + soft_device_name.has_id = false; + device_set_->FindMatchingDevices(soft_device_name, &devices); + if (!devices.empty()) { + devices = FilterSupportedDevices( + devices, members_[node_root].supported_device_types); + } + } + + if (devices.empty()) { + // Return an error when a physical device that matches an explicit + // device specification is not found. This ensures that we don't + // assign a node to GPU when the user wanted to force it on CPU. + DeviceNameUtils::ParsedName specified_device_name; + if (DeviceNameUtils::ParseFullName(node->def().device(), + &specified_device_name) && + specified_device_name == members_[node_root].device_name) { + // The specified device and merged set device match, and + // will appear in the GraphDef (for debugging), so just + // print the specified device. + return errors::InvalidArgument( + "Could not satisfy explicit device specification '", + node->def().device(), "'"); + } else { + // The specified device may be a valid device but the + // merged set device is different, so print both. + return errors::InvalidArgument( + "Could not satisfy explicit device specification '", + node->def().device(), + "' because the node was colocated with a group of nodes that " + "required incompatible device '", + DeviceNameUtils::ParsedNameToString( + members_[node_root].device_name), + "'"); + } + } + } else { + // The device is completely unspecified, so enumerate the devices that + // support all of the nodes in the set. + if (device_set_->devices().empty()) { + return errors::Internal("No devices are registered"); + } + devices = FilterSupportedDevices( + device_set_->devices(), members_[node_root].supported_device_types); + + if (devices.empty()) { + return errors::InvalidArgument( + "Node had no OpKernel registered to support this operation: ", + "Operation was ", node->type_string(), " and inputs were ", + DataTypeVectorString(node->input_types())); + } + } + + // Returns the first device in sorted devices list so we will always + // choose the same device. + members_[node_root].assigned_device = devices[0]; + } + node->set_assigned_device_name(members_[node_root].assigned_device->name()); + + // Log placement if log_device_placement is set. + if (options_ && options_->config.log_device_placement()) { + printf("%s: %s\n", node->name().c_str(), + node->assigned_device_name().c_str()); + LOG(INFO) << node->name() << ": " << node->assigned_device_name(); + } + + return Status::OK(); + } + + private: + // Represents a node in the disjoint node set forest, and the + // accumulated constraints on the device used by that node. + struct Member { + Member() = default; + // The id of the node that is the parent of this one, or its own + // id if it is a root. parent <= 0 indicates that this member is invalid. + int parent = -1; + // A proxy for the depth of the tree that is used to prefer + // connecting smaller trees to larger trees when merging disjoint + // sets. 
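+ // Rank only increases when two roots of equal rank are merged, so it
+ // bounds the depth of the tree logarithmically (standard union by rank).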
+ int rank = 0; + // The intersection of all device types supported by this node, + // and those of all of its children, in priority order + // of the preferred device. + DeviceTypeVector supported_device_types; + // The merged form of the device requested for this node, with + // those of all of its children. + DeviceNameUtils::ParsedName device_name; + // If this node is a root, stores the Device to which this node + // and all of its children have been assigned, or nullptr if this + // has not yet been computed by GetAssignedDevice(). + Device* assigned_device = nullptr; + }; + + Status InitializeMember(const Node& node, Member* member) { + const int id = node.id(); + if (id < 0) { + return errors::InvalidArgument("Node id was not positive: ", id); + } + member->parent = id; + TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( + device_types_, node.def(), &member->supported_device_types)); + + if (!node.assigned_device_name().empty()) { + // This node has already been assigned to a device, so we + // respect this placement, after sanity-checking it. The + // device_name and supported_device_types for this node reflect + // the assigned device, so any nodes colocated with this node + // will be assigned to the same device (assuming this is + // possible). + // NOTE: Since any assignment must have been performed by + // the TensorFlow runtime, we consider errors in this branch to + // be INTERNAL. + if (!DeviceNameUtils::ParseFullName(node.assigned_device_name(), + &member->device_name)) { + return errors::Internal("Malformed assigned device '", + node.assigned_device_name(), "'"); + } + std::vector devices; + const Device* assigned_device = + device_set_->FindDeviceByName(node.assigned_device_name()); + if (assigned_device == nullptr) { + return errors::Internal("Assigned device '", + node.assigned_device_name(), + "' does not match any device"); + } + + for (DeviceType d : member->supported_device_types) { + if (DeviceType(assigned_device->attributes().device_type()) == d) { + return Status::OK(); + } + } + + return errors::Internal("Assigned device '", node.assigned_device_name(), + "' does not have registered OpKernel support " + "for ", + node.def().op()); + } else { + // This node has not yet been assigned to a device, so we + // calculate any constraints due to the set of registered + // kernels and any (partial) user-provided device specification + // in the NodeDef. + + // If no kernels are registered for this op type, fail with an error. + if (member->supported_device_types.empty()) { + return errors::InvalidArgument( + "No OpKernel was registered to support " + "Op '", + node.def().op(), "' with these attrs"); + } + + // If the NodeDef contains a device that is *not* a colocated node name + // (i.e. it does not begin with '@') then we interpret it as a (partial) + // device specification. + string colocated_node_name; + if (!node.def().device().empty() && !HasColocatedNodeName(node)) { + // The user has specified a device in the NodeDef, try to find a + // valid device matching their specification in the set of + // devices. + // NOTE: The full name may specify a device that is not in + // n.supported_device_types(), but we check that in AssignDevice(). + if (!DeviceNameUtils::ParseFullName(node.def().device(), + &member->device_name)) { + return errors::InvalidArgument("Malformed device specification '", + node.def().device(), "'"); + } + } + } + return Status::OK(); + } + + // Updates target to contain the intersection of the device types in + // "target" and "other". 
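+ // For example, target = {GPU, CPU} and other = {CPU} leaves target as
+ // {CPU}; the relative priority order already present in "target" is
+ // preserved.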
+ static void MergeSupportedDevices(DeviceTypeVector* target, + const DeviceTypeVector& other) { + DeviceTypeVector temp = *target; + target->clear(); + + // Iterate in priority order. + for (DeviceType device_type : temp) { + bool found = false; + for (DeviceType other_device_type : other) { + if (device_type == other_device_type) { + found = true; + break; + } + } + if (found) { + target->push_back(device_type); + } + } + } + + // Returns the root node of the disjoint tree to which the node with the + // given id is connected. + int FindRoot(int node_id) { + DCHECK_GE(members_[node_id].parent, 0); + if (members_[node_id].parent != node_id) { + // NOTE: Compress paths from node_id to its root, so that future + // calls to FindRoot and ColocateNodes are more efficient. + members_[node_id].parent = FindRoot(members_[node_id].parent); + } + return members_[node_id].parent; + } + + std::vector members_; + const DeviceSet* device_set_; // Not owned. + const std::vector device_types_; + const SessionOptions* options_; // Not owned; +}; + +} // namespace + +SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices, + const NodeNameToIdMap* name_to_id_map, + const SessionOptions* options) + : graph_(graph), + devices_(devices), + name_to_id_map_(name_to_id_map), + options_(options) {} + +SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices, + const NodeNameToIdMap* name_to_id_map) + : graph_(graph), devices_(devices), name_to_id_map_(name_to_id_map) { + options_ = nullptr; +} + +SimplePlacer::~SimplePlacer() {} + +Status SimplePlacer::Run() { + if (devices_->devices().empty()) { + return errors::FailedPrecondition("No devices are registered"); + } + + ColocationGraph colocation_graph(graph_, devices_, options_); + Status status; + + // 1. First add all of the nodes. Note that steps (1) and (2) + // requires two passes over the nodes because the graph (and hence + // the constraints) may not be acyclic. + for (Node* node : graph_->nodes()) { + // Skip the source and sink nodes. + if (!node->IsOp()) { + continue; + } + status = colocation_graph.AddNode(*node); + if (!status.ok()) return AttachDef(status, node->def()); + } + + // 2. Enumerate the constraint edges, and use them to update the disjoint + // node set. + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + + // 2(a). If node n specifies a colocation constraint as its device name, + // add an edge from the colocated node to n. + if (HasColocatedNodeName(*node)) { + string colocated_node_name; + status = ParseColocatedNodeName(*node, &colocated_node_name); + if (!status.ok()) { + return AttachDef(status, node->def()); + } + Node* colocated_node; + status = GetNodeByName(colocated_node_name, &colocated_node); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument("Colocated node named in device '", + colocated_node_name, "' does not exist"), + node->def()); + } + status = colocation_graph.ColocateNodes(*colocated_node, *node); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument( + "Cannot satisfy colocation constraint named in device '", + colocated_node_name, "': ", status.error_message()), + node->def()); + } + } + + // 2(b). If `node` has an input edge with reference type, add an + // edge from the source of that edge to `node`. 
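+ // (A reference edge means the consumer operates in place on the
+ // producer's buffer, as an Assign op does with its Variable, so both
+ // nodes must end up on the same device.)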
+ for (const auto& edge : node->in_edges()) { + if (!edge->IsControlEdge() && + IsRefType(node->input_type(edge->dst_input()))) { + status = colocation_graph.ColocateNodes(*edge->src(), *node); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument("Cannot satisfy colocation constraint " + "implied by reference connection: ", + status.error_message()), + node->def()); + } + } + } + } + + // 3. For each node, assign a device based on the constraints in the + // disjoint node set. + for (Node* node : graph_->nodes()) { + // Skip the source and sink nodes. + if (!node->IsOp()) { + continue; + } + // Skip nodes that already have an assigned name. + if (!node->assigned_device_name().empty()) { + continue; + } + + status = colocation_graph.AssignDevice(node); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument("Cannot assign a device to node '", + node->name(), "': ", status.error_message()), + node->def()); + } + } + return Status::OK(); +} + +Status SimplePlacer::GetNodeByName(const string& name, Node** out_node) const { + NodeNameToIdMap::const_iterator iter = name_to_id_map_->find(name); + if (iter != name_to_id_map_->end()) { + *out_node = graph_->FindNodeId(iter->second); + if (*out_node) { + return Status::OK(); + } + } + return errors::NotFound(name); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/simple_placer.h new file mode 100644 index 0000000000..4b3df50c72 --- /dev/null +++ b/tensorflow/core/common_runtime/simple_placer.h @@ -0,0 +1,81 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ +#define TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// A placement algorithm that assigns the nodes of the given Graph to +// devices the given DeviceSet, respecting the following constraints: +// +// 1. Existing device assignments remain unchanged. +// 2. Requested (partial or complete) device specifications in the +// are granted. +// 3. Nodes connected by edges of a reference type are colocated on +// the same device. +// 4. Given nodes "A" and "B", if node "B" has the device specification +// "@A", nodes "A" and "B" will be colocated on the same device. +// +// The implementation builds a constraint graph with the same set of +// nodes, and edges that represent colocation constraints between +// nodes. Each connected component in the resulting constraint graph +// is then assigned to a single device. +// +// TODO(mrry): "Soft" constraints, such as "place node 'x' as close as +// possible to node 'y' while respecting the other constraints"? +// TODO(mrry): Create a common interface for this and the other +// placement algorithms so that they may be injected into the graph +// builder. +class SimplePlacer { + public: + // A map from graph node names to numerical IDs (in a Graph object). + typedef std::unordered_map NodeNameToIdMap; + + // Creates an instance of the SimplePlacer algorithm for the given + // Graph "graph" (nodes in which may or may not be assigned) on the + // given DeviceSet "devices". The "name_to_id_map" maps the names of + // nodes in "g" to their numerical ID. 
+ // + // REQUIRES: for all mappings (k, v) in "name_to_id_map", + // graph.FindNodeId(v)->name() == k. + // + // The "graph", "devices", and "name_to_id_map" pointer arguments + // are borrowed by this SimplePlacer, and must outlive it. + SimplePlacer(Graph* graph, const DeviceSet* devices, + const NodeNameToIdMap* name_to_id_map, + const SessionOptions* options); + + SimplePlacer(Graph* graph, const DeviceSet* devices, + const NodeNameToIdMap* name_to_id_map); + + ~SimplePlacer(); + + // Assigns each node in this SimplePlacer's graph to a device in its + // set of devices. + // + // This method is not thread-safe. + // Run() may be invoked at most once. + Status Run(); + + private: + Status GetNodeByName(const string& name, Node** out_node) const; + + Graph* const graph_; // Not owned. + const DeviceSet* const devices_; // Not owned. + const NodeNameToIdMap* const name_to_id_map_; // Not owned. + const SessionOptions* options_; // Not owned. + + TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc new file mode 100644 index 0000000000..3139962d7e --- /dev/null +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -0,0 +1,863 @@ +#include "tensorflow/core/common_runtime/simple_placer.h" + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include + +namespace tensorflow { + +namespace { + +//////////////////////////////////////////////////////////////////////////////// +// +// Op, kernel, and device registrations to set up the environment. +// +// The SimplePlacer uses information about the op (input types), +// kernel (device constraints), and available devices to make +// placement decisions. To avoid depending on the full runtime, we +// define dummy implementations of these, and register them with the +// runtime. +// +//////////////////////////////////////////////////////////////////////////////// + +// A dummy OpKernel that is used to register ops on different devices. +class DummyOp : public OpKernel { + public: + explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +// A fake device that has specific device attributes, used to simulate +// the presence of a CPU or a GPU (without depending on that part of +// the runtime. 
+class FakeDevice : public Device { + private: + explicit FakeDevice(const DeviceAttributes& device_attributes) + : Device(nullptr, device_attributes, nullptr) {} + + public: + Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } + + Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } + + static std::unique_ptr MakeCPU(const string& name) { + DeviceAttributes device_attributes; + device_attributes.set_name(name); + device_attributes.set_device_type(DeviceType(DEVICE_CPU).type()); + return std::unique_ptr(new FakeDevice(device_attributes)); + } + + static std::unique_ptr MakeGPU(const string& name) { + DeviceAttributes device_attributes; + device_attributes.set_name(name); + device_attributes.set_device_type(DeviceType(DEVICE_GPU).type()); + return std::unique_ptr(new FakeDevice(device_attributes)); + } +}; + +// Register the following ops so they can be added to a Graph, and +// kernels so that they can be placed on particular device types. +REGISTER_OP("TestVariable").Output("o: Ref(float)"); +REGISTER_KERNEL_BUILDER(Name("TestVariable").Device(DEVICE_CPU), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestVariable").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("VariableCPU").Output("o: Ref(float)"); +REGISTER_KERNEL_BUILDER(Name("VariableCPU").Device(DEVICE_CPU), DummyOp); + +REGISTER_OP("VariableGPU").Output("o: Ref(float)"); +REGISTER_KERNEL_BUILDER(Name("VariableGPU").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("VariableNoKernels").Output("o: Ref(float)"); + +REGISTER_OP("TestAdd").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("TestAdd").Device(DEVICE_CPU), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestAdd").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("TestRelu").Input("i: float").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("TestRelu").Device(DEVICE_CPU), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestRelu").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("ReluGPU").Input("i: float").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("ReluGPU").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("TestAssign").Input("i: Ref(float)").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("TestAssign").Device(DEVICE_CPU), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestAssign").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("AssignCPU").Input("i: Ref(float)").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("AssignCPU").Device(DEVICE_CPU), DummyOp); + +REGISTER_OP("AssignGPU").Input("i: Ref(float)").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("AssignGPU").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_KERNEL_BUILDER(Name("TestInput").Device(DEVICE_CPU), DummyOp); + +REGISTER_OP("TestDevice").Output("a: float").Output("b: float"); +REGISTER_KERNEL_BUILDER(Name("TestDevice").Device(DEVICE_GPU), DummyOp); + +REGISTER_OP("TestDeviceEnforce").Input("a: Ref(float)").Output("b: float"); +REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device(DEVICE_CPU), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device(DEVICE_GPU), DummyOp); + +//////////////////////////////////////////////////////////////////////////////// +// +// A SimplePlacerTest method has three phases: +// +// 1. Build a TensorFlow graph, with no (or partial) device assignments. +// 2. Attempt to compute a placement using the SimplePlacer. +// 3. EITHER: test that the constraints implied by the graph are respected; +// or that an appropriate error was reported. 
+// +//////////////////////////////////////////////////////////////////////////////// +class SimplePlacerTest : public ::testing::Test { + protected: + SimplePlacerTest() { + RequireDefaultOps(); + // Build a set of 10 GPU and 10 CPU devices. + // NOTE: this->local_devices_ owns the device objects; + // this->devices_ contains borrowed pointers to the device + // objects. + for (int i = 0; i < 10; ++i) { + local_devices_.emplace_back(FakeDevice::MakeCPU( + strings::StrCat("/job:a/replica:0/task:0/cpu:", i))); + devices_.AddDevice(local_devices_.back().get()); + // Insert the GPUs in reverse order. + local_devices_.emplace_back(FakeDevice::MakeGPU( + strings::StrCat("/job:a/replica:0/task:0/gpu:", 9 - i))); + devices_.AddDevice(local_devices_.back().get()); + } + } + + // Builds the given graph, and (if successful) indexes the node + // names for use in placement, and later lookup. + Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { + TF_RETURN_IF_ERROR(builder.ToGraph(out_graph)); + nodes_by_name_.clear(); + for (Node* node : out_graph->nodes()) { + nodes_by_name_[node->name()] = node->id(); + } + return Status::OK(); + } + + // Invokes the SimplePlacer on "graph". If no DeviceSet is specified, the + // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices). + // + // REQUIRES: "*graph" was produced by the most recent call to BuildGraph. + Status Place(Graph* graph, DeviceSet* devices, SessionOptions* options) { + SimplePlacer placer(graph, devices, &nodes_by_name_, options); + return placer.Run(); + } + + Status Place(Graph* graph, DeviceSet* devices) { + return Place(graph, devices, nullptr); + } + + Status Place(Graph* graph, SessionOptions* options) { + return Place(graph, &devices_, options); + } + + Status Place(Graph* graph) { return Place(graph, &devices_, nullptr); } + + // Returns the node in "graph" with the given name. + // + // REQUIRES: "graph" was produced by the most recent call to BuildGraph. + Node* GetNodeByName(const Graph& graph, const string& name) { + const auto search = nodes_by_name_.find(name); + CHECK(search != nodes_by_name_.end()) << "Unknown node name: " << name; + return graph.FindNodeId(search->second); + } + + protected: + std::vector> local_devices_; + DeviceSet devices_; + SimplePlacer::NodeNameToIdMap nodes_by_name_; + + Status ReferenceTestHelper(const string& variable_op_type, + const string& assign_op_type, + DeviceType expected_device_type); +}; + +#define EXPECT_COLOCATED(g, name_a, name_b) \ + do { \ + Graph& g_ = (g); \ + EXPECT_EQ(GetNodeByName(g_, (name_a))->assigned_device_name(), \ + GetNodeByName(g_, (name_b))->assigned_device_name()); \ + } while (0) + +#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \ + EXPECT_EQ(DeviceType(expected_device_type).type(), \ + devices_.FindDeviceByName( \ + GetNodeByName((g), (name))->assigned_device_name()) \ + ->attributes() \ + .device_type()) + +#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \ + EXPECT_TRUE(StringPiece(GetNodeByName((g), (name))->assigned_device_name()) \ + .contains(device_substr)) + +// Test that a graph with no constraints will successfully assign nodes to the +// "best available" device (i.e. prefer GPU over CPU). +TEST_F(SimplePlacerTest, TestNoConstraints) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+    ops::UnaryOp("TestRelu", ops::NodeOut(input, 0), b.opts().WithName("n1"));
+    ops::UnaryOp("TestRelu", ops::NodeOut(input, 1), b.opts().WithName("n2"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+  EXPECT_DEVICE_TYPE(g, "n1", DEVICE_GPU);
+  EXPECT_DEVICE_TYPE(g, "n2", DEVICE_GPU);
+}
+
+// Test that a graph with device type and reference constraints on
+// some of the ops will successfully assign nodes to the constrained
+// device, and colocate nodes with reference connections.
+TEST_F(SimplePlacerTest, TestDeviceTypeConstraints) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+    Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
+    ops::BinaryOp("AssignCPU", var_cpu, input, b.opts().WithName("assign_cpu"));
+    Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
+    ops::BinaryOp("AssignGPU", var_gpu, input, b.opts().WithName("assign_gpu"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+  EXPECT_DEVICE_TYPE(g, "var_cpu", DEVICE_CPU);
+  EXPECT_DEVICE_TYPE(g, "assign_cpu", DEVICE_CPU);
+  EXPECT_COLOCATED(g, "var_cpu", "assign_cpu");
+  EXPECT_DEVICE_TYPE(g, "var_gpu", DEVICE_GPU);
+  EXPECT_DEVICE_TYPE(g, "assign_gpu", DEVICE_GPU);
+  EXPECT_COLOCATED(g, "var_gpu", "assign_gpu");
+}
+
+// Test that a graph with partial device specifications on the ops
+// will successfully be placed on devices that match those partial
+// specifications.
+TEST_F(SimplePlacerTest, TestPartialSpec) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:a"));
+    ops::SourceOp("TestVariable",
+                  b.opts().WithName("var").WithDevice("/job:a"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+  EXPECT_DEVICE_CONTAINS(g, "in", "/job:a");
+  EXPECT_DEVICE_TYPE(g, "var", DEVICE_GPU);
+  EXPECT_DEVICE_CONTAINS(g, "var", "/job:a");
+}
+
+// Test that a node with an assigned device is not relocated.
+TEST_F(SimplePlacerTest, TestAssignedDevicePreserved) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("TestInput", b.opts().WithName("in"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  GetNodeByName(g, "in")
+      ->set_assigned_device_name("/job:a/replica:0/task:0/cpu:7");
+
+  EXPECT_OK(Place(&g));
+  EXPECT_EQ("/job:a/replica:0/task:0/cpu:7",
+            GetNodeByName(g, "in")->assigned_device_name());
+}
+
+// Test that CPU-only ops with a partial GPU device specification are
+// relocated to CPU when soft placement is allowed.
+TEST_F(SimplePlacerTest, TestPartialSpecGpuToCpu) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/gpu:0"));
+    ops::SourceOp("TestVariable",
+                  b.opts().WithName("var").WithDevice("/gpu:0"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  SessionOptions options;
+  options.config.set_allow_soft_placement(true);
+  EXPECT_OK(Place(&g, &options));
+  EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+  EXPECT_DEVICE_CONTAINS(g, "in", "/cpu");
+  EXPECT_DEVICE_TYPE(g, "var", DEVICE_GPU);
+  EXPECT_DEVICE_CONTAINS(g, "var", "/gpu:0");
+}
+
+// Test that placement fails for a node with an assigned GPU device for
+// which no OpKernel is registered.
+TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("TestInput", b.opts().WithName("in"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  GetNodeByName(g, "in")
+      ->set_assigned_device_name("/job:a/replica:0/task:0/gpu:0");
+
+  Status s = Place(&g);
+  EXPECT_EQ(error::INTERNAL, s.code());
+  EXPECT_TRUE(
+      StringPiece(s.error_message())
+          .contains("Assigned device '/job:a/replica:0/task:0/gpu:0' "
+                    "does not have registered OpKernel support for TestInput"));
+}
+
+// Test that graphs with reference connections are correctly placed.
+
+// Build a graph containing a Variable op of "variable_op_type" and an
+// Assign op of "assign_op_type", and expect all of the ops to be
+// placed on a device of type "expected_device_type".
+Status SimplePlacerTest::ReferenceTestHelper(const string& variable_op_type,
+                                             const string& assign_op_type,
+                                             DeviceType expected_device_type) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+    // Build ten variable-and-assignment pairs.
+    for (int i = 0; i < 10; ++i) {
+      Node* var = ops::SourceOp(variable_op_type,
+                                b.opts().WithName(strings::StrCat("var_", i)));
+      ops::BinaryOp(assign_op_type, var, input,
+                    b.opts().WithName(strings::StrCat("assign_", i)));
+    }
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  TF_RETURN_IF_ERROR(Place(&g));
+
+  for (int i = 0; i < 10; ++i) {
+    EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+                     strings::StrCat("assign_", i));
+    EXPECT_DEVICE_TYPE(g, strings::StrCat("var_", i), expected_device_type);
+    EXPECT_DEVICE_TYPE(g, strings::StrCat("assign_", i), expected_device_type);
+  }
+
+  return Status::OK();
+}
+
+// Test all 3 x 3 = 9 combinations of Variable and Assignment op types
+// (unconstrained, CPU-only, and GPU-only).
+TEST_F(SimplePlacerTest, TestReferenceConnection) { + Status s; + EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", DEVICE_GPU)); + EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", DEVICE_CPU)); + EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignGPU", DEVICE_GPU)); + EXPECT_OK(ReferenceTestHelper("VariableCPU", "TestAssign", DEVICE_CPU)); + EXPECT_OK(ReferenceTestHelper("VariableCPU", "AssignCPU", DEVICE_CPU)); + { + Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", DEVICE_CPU); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("no device type supports both of those nodes")); + } + EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", DEVICE_GPU)); + { + Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", DEVICE_CPU); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("no device type supports both of those nodes")); + } + EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", DEVICE_GPU)); +} + +// Test the handling of '@node_name' colocation constraints, when +// these are arranged in multiple chains. +TEST_F(SimplePlacerTest, TestColocatedChain) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* last_node = input; + for (int i = 0; i < 100; ++i) { + if (i % 10 == 0) { + // Every ten nodes, start a new chain. + last_node = ops::UnaryOp("TestRelu", last_node, + b.opts().WithName(strings::StrCat("n_", i))); + } else { + // Chain each successive node to the previous one. + last_node = + ops::UnaryOp("TestRelu", last_node, + b.opts() + .WithName(strings::StrCat("n_", i)) + .WithDevice(strings::StrCat("@n_", i - 1))); + } + } + EXPECT_OK(BuildGraph(b, &g)); + } + + EXPECT_OK(Place(&g)); + for (int i = 0; i < 100; ++i) { + if (i % 10 != 0) { + EXPECT_COLOCATED(g, strings::StrCat("n_", i - (i % 1)), + strings::StrCat("n_", i)); + } + } +} + +// Test the handling of '@node_name' colocation constraints, when the +// chains are shuffled. +TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* last_node = input; + for (int i = 0; i < 10; ++i) { + // Start ten chains. + last_node = ops::UnaryOp("TestRelu", last_node, + b.opts().WithName(strings::StrCat("n_", i))); + } + for (int i = 10; i < 100; ++i) { + // Add each node to the (i % 10)^th chain. + last_node = ops::UnaryOp("TestRelu", last_node, + b.opts() + .WithName(strings::StrCat("n_", i)) + .WithDevice(strings::StrCat("@n_", i % 10))); + } + EXPECT_OK(BuildGraph(b, &g)); + } + + EXPECT_OK(Place(&g)); + for (int i = 10; i < 100; ++i) { + EXPECT_COLOCATED(g, strings::StrCat("n_", i % 10), + strings::StrCat("n_", i)); + } +} + +TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + for (int i = 0; i < 10; ++i) { + // Declare ten variable and assignment pairs. 
+ Node* var = ops::SourceOp("TestVariable", + b.opts().WithName(strings::StrCat("var_", i))); + ops::BinaryOp("TestAssign", var, input, + b.opts().WithName(strings::StrCat("assign_", i))); + } + for (int i = 10; i < 100; ++i) { + // Create a variable colocated with some existing variable, and + // an assignment colocated with a possibly-different variable. + Node* var = ops::SourceOp( + "TestVariable", b.opts() + .WithName(strings::StrCat("var_", i)) + .WithDevice(strings::StrCat("@var_", i % 6))); + ops::BinaryOp("TestAssign", var, input, + b.opts() + .WithName(strings::StrCat("assign_", i)) + .WithDevice(strings::StrCat("@assign_", i % 3))); + } + EXPECT_OK(BuildGraph(b, &g)); + } + + EXPECT_OK(Place(&g)); + for (int i = 0; i < 10; ++i) { + EXPECT_COLOCATED(g, strings::StrCat("var_", i), + strings::StrCat("assign_", i)); + } + for (int i = 10; i < 100; ++i) { + EXPECT_COLOCATED(g, strings::StrCat("var_", i), + strings::StrCat("assign_", i)); + EXPECT_COLOCATED(g, strings::StrCat("var_", i), + strings::StrCat("var_", i % 6)); + EXPECT_COLOCATED(g, strings::StrCat("assign_", i), + strings::StrCat("assign_", i % 3)); + } +} + +// Test that placement fails when no devices are registered. +TEST_F(SimplePlacerTest, TestEmptyDeviceSet) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + EXPECT_OK(BuildGraph(b, &g)); + } + + DeviceSet empty; + + Status s = Place(&g, &empty); + EXPECT_TRUE( + StringPiece(s.error_message()).contains("No devices are registered")); +} + +// Test that placement fails when the requested device forces an +// indirect constraint to be violated. +TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* in = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var")); + ops::BinaryOp("TestAssign", var, in, + b.opts().WithName("assign").WithDevice("/job:b/task:1")); + EXPECT_OK(BuildGraph(b, &g)); + } + + DeviceSet heterogeneous; + std::unique_ptr gpu( + FakeDevice::MakeGPU("/job:b/replica:0/task:0/gpu:0")); + heterogeneous.AddDevice(gpu.get()); + std::unique_ptr cpu( + FakeDevice::MakeCPU("/job:b/replica:0/task:1/cpu:0")); + heterogeneous.AddDevice(cpu.get()); + Status s = Place(&g, &heterogeneous); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("colocated with a group of nodes that required " + "incompatible device")); +} + +// Test that placement fails when an unknown device is requested. +TEST_F(SimplePlacerTest, TestUnknownDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo")); + EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains( + "Could not satisfy explicit device specification '/job:foo'")); +} + +// Test that placement fails when the combination of partial +// constraints leads to an unknown device. +TEST_F(SimplePlacerTest, TestUnknownMergedDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo")); + EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains( + "Could not satisfy explicit device specification '/job:foo'")); +} + +// Test that placement fails when the previously-assigned device for a +// node is unknown. +TEST_F(SimplePlacerTest, TestUnknownAssignedDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + EXPECT_OK(BuildGraph(b, &g)); + } + + GetNodeByName(g, "in")->set_assigned_device_name("/job:foo"); + + Status s = Place(&g); + EXPECT_EQ(error::INTERNAL, s.code()); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains("Assigned device '/job:foo' does not match any device")); +} + +// Test that placement fails when an op with no registered kernels is +// requested. +TEST_F(SimplePlacerTest, TestNoKernelsRegistered) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("VariableNoKernels", b.opts().WithName("var")); + EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains( + "No OpKernel was registered to support Op 'VariableNoKernels'")); +} + +// Test that placement fails when a kernel is registered but no known +// device supports it. +TEST_F(SimplePlacerTest, TestNoDevicesRegistered) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("VariableGPU", b.opts().WithName("var")); + EXPECT_OK(BuildGraph(b, &g)); + } + + DeviceSet cpu_only; + std::unique_ptr cpu( + FakeDevice::MakeCPU("/job:a/replica:0/task:0/cpu:0")); + cpu_only.AddDevice(cpu.get()); + + Status s = Place(&g, &cpu_only); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("No OpKernel was registered to support " + "Op 'VariableGPU'")); +} + +// Test that placement fails when a requested device is malformed. +TEST_F(SimplePlacerTest, TestMalformedDeviceSpecification) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/foo:bar")); + EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Malformed device specification '/foo:bar'")); +} + +// Test that placement fails when a previously-assigned device is malformed. +TEST_F(SimplePlacerTest, TestMalformedAssignedDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + EXPECT_OK(BuildGraph(b, &g)); + } + + GetNodeByName(g, "in")->set_assigned_device_name("/foo:bar"); + + Status s = Place(&g); + EXPECT_EQ(error::INTERNAL, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Malformed assigned device '/foo:bar'")); +} + +// Test that placement fails when a device was previously assigned to +// a node, but it does not uniquely identify a particular device. +TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + EXPECT_OK(BuildGraph(b, &g)); + } + + GetNodeByName(g, "in")->set_assigned_device_name("/job:a"); + + Status s = Place(&g); + EXPECT_EQ(error::INTERNAL, s.code()); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains("Assigned device '/job:a' does not match any device")); +} + +// Test that placement fails when a node requests colocation with another +// node that does not exist. +TEST_F(SimplePlacerTest, TestUnknownColocatedNode) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("@foo")); + EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()).contains("'foo' does not exist")); +} + +// Test that placement fails when a node requests colocation with a +// malformed node name. +TEST_F(SimplePlacerTest, TestMalformedColocatedNode) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("@")); + EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("node named in device '' does not exist")); +} + +// Test that ops request to be placed on non-existent devices will be relocated +// to existing device of the same type if allow_soft_placement is set. +TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestDevice", b.opts().WithName("in").WithDevice("/gpu:11")); + EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + options.config.set_allow_soft_placement(true); + EXPECT_OK(Place(&g, &options)); + EXPECT_DEVICE_CONTAINS(g, "in", "/gpu:0"); +} + +// Test that ops request to be placed on non-existent devices will fail if +// allow_soft_placement is not set. +TEST_F(SimplePlacerTest, TestNonexistentGpuNoAllowSoftPlacement) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("TestDevice", b.opts().WithName("in").WithDevice("/gpu:11"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  SessionOptions options;
+  Status s = Place(&g, &options);
+  EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+  EXPECT_TRUE(
+      StringPiece(s.error_message())
+          .contains(
+              "Could not satisfy explicit device specification '/gpu:11'"));
+}
+
+// Test that placement fails when a node requests an explicit device that is
+// not supported by the registered kernels and allow_soft_placement is not set.
+TEST_F(SimplePlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("VariableGPU", b.opts().WithName("var").WithDevice("/cpu:0"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  SessionOptions options;
+  Status s = Place(&g, &options);
+  EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+  EXPECT_TRUE(
+      StringPiece(s.error_message())
+          .contains(
+              "Could not satisfy explicit device specification '/cpu:0'"));
+}
+
+TEST_F(SimplePlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("VariableGPU", b.opts().WithName("var").WithDevice("/cpu:0"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  SessionOptions options;
+  options.config.set_allow_soft_placement(true);
+  EXPECT_OK(Place(&g, &options));
+}
+
+// Test that a graph with device type and reference constraints on
+// some of the ops will successfully assign nodes to the constrained
+// device, and colocate nodes with reference connections.
+TEST_F(SimplePlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    // var_gpu has a ref output and runs on GPU.
+    // force_gpu takes var_gpu and requests CPU.
+    // Verify that both are placed on GPU.
+    Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
+    ops::UnaryOp("TestDeviceEnforce", var_gpu,
+                 b.opts().WithName("force_gpu").WithDevice("/cpu:0"));
+    // var_cpu has a ref output and runs on CPU.
+    // force_cpu takes var_cpu and requests GPU.
+    // Verify that both are placed on CPU.
+    Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
+    ops::UnaryOp("TestDeviceEnforce", var_cpu,
+                 b.opts().WithName("force_cpu").WithDevice("/gpu:0"));
+    EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  SessionOptions options;
+  options.config.set_allow_soft_placement(true);
+  EXPECT_OK(Place(&g, &options));
+  EXPECT_DEVICE_TYPE(g, "var_gpu", DEVICE_GPU);
+  EXPECT_DEVICE_TYPE(g, "force_gpu", DEVICE_GPU);
+  EXPECT_COLOCATED(g, "var_gpu", "force_gpu");
+  EXPECT_DEVICE_TYPE(g, "var_cpu", DEVICE_CPU);
+  EXPECT_DEVICE_TYPE(g, "force_cpu", DEVICE_CPU);
+  EXPECT_COLOCATED(g, "var_cpu", "force_cpu");
+}
+
+// Test that placement fails when two nodes have a reference connection
+// constraint, and each node requires a mutually incompatible device.
+TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var")); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + ops::BinaryOp("AssignCPU", var, input, b.opts().WithName("assign")); + EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Cannot colocate nodes 'var' and 'assign'")); +} + +// Test that placement fails when two nodes have an explicit +// colocation constraint, and each node requires a mutually +// incompatible device. +TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithColocatedNodes) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", + b.opts().WithName("in").WithDevice("/gpu:0")); + Node* relu_1 = ops::UnaryOp("TestRelu", input, + b.opts().WithName("relu_1").WithDevice("@in")); + ops::UnaryOp("ReluGPU", relu_1, + b.opts().WithName("relu_2").WithDevice("@relu_1")); + EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Cannot colocate nodes 'relu_1' and 'relu_2'")); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc new file mode 100644 index 0000000000..4806e69c67 --- /dev/null +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -0,0 +1,55 @@ +#include "tensorflow/core/common_runtime/threadpool_device.h" + +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, + const string& name, Bytes memory_limit, + BusAdjacency bus_adjacency, + Allocator* allocator) + : LocalDevice(options, Device::BuildDeviceAttributes( + name, DEVICE_CPU, memory_limit, bus_adjacency), + allocator), + allocator_(allocator) {} + +ThreadPoolDevice::~ThreadPoolDevice() {} + +void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { + if (port::Tracing::IsActive()) { + // TODO(pbar) We really need a useful identifier of the graph node. 
+ const uint64 id = Hash64(op_kernel->name()); + port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute, + id); + op_kernel->Compute(context); + } else { + op_kernel->Compute(context); + } +} + +Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) { + return allocator_; +} + +Status ThreadPoolDevice::MakeTensorFromProto( + const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, + Tensor* tensor) { + Tensor parsed(tensor_proto.dtype()); + if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + tensor_proto.DebugString()); + } + *tensor = parsed; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h new file mode 100644 index 0000000000..5b0347231f --- /dev/null +++ b/tensorflow/core/common_runtime/threadpool_device.h @@ -0,0 +1,31 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_ +#define TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/local_device.h" + +namespace tensorflow { + +// CPU device implementation. +class ThreadPoolDevice : public LocalDevice { + public: + ThreadPoolDevice(const SessionOptions& options, const string& name, + Bytes memory_limit, BusAdjacency bus_adjacency, + Allocator* allocator); + ~ThreadPoolDevice() override; + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + Allocator* GetAllocator(AllocatorAttributes attr) override; + Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + + Status Sync() override { return Status::OK(); } + + private: + Allocator* allocator_; // Not owned +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc new file mode 100644 index 0000000000..ee6319abad --- /dev/null +++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc @@ -0,0 +1,31 @@ +// Register a factory that provides CPU devices. +#include "tensorflow/core/common_runtime/threadpool_device.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// TODO(zhifengc/tucker): Figure out the bytes of available RAM. +class ThreadPoolDeviceFactory : public DeviceFactory { + public: + void CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector* devices) override { + // TODO(zhifengc/tucker): Figure out the number of available CPUs + // and/or NUMA configuration. 
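+    // The "CPU" entry of SessionOptions.config.device_count() overrides the
+    // default of a single device. For instance (an illustrative sketch, not
+    // part of this factory), a caller could request four CPU devices with:
+    //   SessionOptions opts;
+    //   (*opts.config.mutable_device_count())["CPU"] = 4;
+    // which makes the loop below create /cpu:0 through /cpu:3.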
+ int n = 1; + auto iter = options.config.device_count().find("CPU"); + if (iter != options.config.device_count().end()) { + n = iter->second; + } + for (int i = 0; i < n; i++) { + string name = strings::StrCat(name_prefix, "/cpu:", i); + devices->push_back(new ThreadPoolDevice(options, name, Bytes(256 << 20), + BUS_ANY, cpu_allocator())); + } + } +}; +REGISTER_LOCAL_DEVICE_FACTORY("CPU", ThreadPoolDeviceFactory); + +} // namespace tensorflow diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto new file mode 100644 index 0000000000..194d1e7c24 --- /dev/null +++ b/tensorflow/core/example/example.proto @@ -0,0 +1,95 @@ +// Protocol messages for describing input data Examples for machine learning +// model training or inference. +syntax = "proto3"; + +import "tensorflow/core/example/feature.proto"; +// option cc_enable_arenas = true; + +package tensorflow; + +// Example for a movie recommendation application: +// features { +// feature { +// key: "age" +// float_list { +// value: 29.0 +// } +// } +// feature { +// key: "movie" +// bytes_list { +// value: "The Shawshank Redemption" +// value: "Fight Club" +// } +// } +// feature { +// key: "movie_ratings" +// float_list { +// value: 9.0 +// value: 9.7 +// } +// } +// feature { +// key: "suggestion" +// bytes_list { +// value: "Inception" +// } +// } +// # Note that this feature exists to be used as a label in training. +// # E.g., if training a logistic regression model to predict purchase +// # probability in our learning tool we would set the label feature to +// # "suggestion_purchased". +// feature { +// key: "suggestion_purchased" +// float_list { +// value: 1.0 +// } +// } +// # Similar to "suggestion_purchased" above this feature exists to be used +// # as a label in training. +// # E.g., if training a linear regression model to predict purchase +// # price in our learning tool we would set the label feature to +// # "purchase_price". +// feature { +// key: "purchase_price" +// float_list { +// value: 9.99 +// } +// } +// } +// +// A conformant data set obeys the following conventions: +// - If a Feature K exists in one example with data type T, it must be of +// type T in all other examples when present. It may be omitted. +// - The number of instances of Feature K list data may vary across examples, +// depending on the requirements of the model. +// - If a Feature K doesn't exist in an example, a K-specific default will be +// used, if configured. +// - If a Feature K exists in an example but contains no items, the intent +// is considered to be an empty tensor and no default will be used. + +message Example { + Features features = 1; +}; + +// Example representing a ranking instance. +message RankingExample { + Features context = 1; + repeated Features positive = 2; + repeated Features negative = 3; +}; + +// Example representing a sequence. +// The context contains features which apply to the entire sequence. +// Each element in example represents an entry in the sequence. +message SequenceExample { + Features context = 1; + repeated Features features = 2; +}; + +// Example representing a list of feature maps. +// The context contains features which apply to all feature maps. 
+message InferenceExample { + Features context = 1; + repeated Features features = 2; +}; diff --git a/tensorflow/core/example/feature.proto b/tensorflow/core/example/feature.proto new file mode 100644 index 0000000000..5ab77c2997 --- /dev/null +++ b/tensorflow/core/example/feature.proto @@ -0,0 +1,82 @@ +// Protocol messages for describing features for machine learning model +// training or inference. +// +// There are three base Feature types: +// - bytes +// - float +// - int64 +// +// Base features are contained in Lists which may hold zero or more values. +// +// Features are organized into categories by name. The Features message +// contains the mapping from name to Feature. +// +// Example Features for a movie recommendation application: +// feature { +// key: "age" +// float_list { +// value: 29.0 +// } +// } +// feature { +// key: "movie" +// bytes_list { +// value: "The Shawshank Redemption" +// value: "Fight Club" +// } +// } +// feature { +// key: "movie_ratings" +// float_list { +// value: 9.0 +// value: 9.7 +// } +// } +// feature { +// key: "suggestion" +// bytes_list { +// value: "Inception" +// } +// } +// feature { +// key: "suggestion_purchased" +// int64_list { +// value: 1 +// } +// } +// feature { +// key: "purchase_price" +// float_list { +// value: 9.99 +// } +// } + +syntax = "proto3"; +// option cc_enable_arenas = true; + +package tensorflow; + +message Feature { + // Each feature can be exactly one kind. + oneof kind { + BytesList bytes_list = 1; + FloatList float_list = 2; + Int64List int64_list = 3; + } +}; + +message Features { + // Map from feature name to feature. + map feature = 1; +}; + +// Containers to hold repeated fundamental features. +message BytesList { + repeated bytes value = 1; +} +message FloatList { + repeated float value = 1 [packed=true]; +} +message Int64List { + repeated int64 value = 1 [packed=true]; +} diff --git a/tensorflow/core/framework/allocation_description.proto b/tensorflow/core/framework/allocation_description.proto new file mode 100644 index 0000000000..f6f4bc0126 --- /dev/null +++ b/tensorflow/core/framework/allocation_description.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +message AllocationDescription { + // Total number of bytes requested + int64 requested_bytes = 1; + + // Total number of bytes allocated if known + int64 allocated_bytes = 2; + + // Name of the allocator used + string allocator_name = 3; +}; diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc new file mode 100644 index 0000000000..93f68dcccb --- /dev/null +++ b/tensorflow/core/framework/allocator.cc @@ -0,0 +1,25 @@ +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +Allocator::~Allocator() {} + +class CPUAllocator : public Allocator { + public: + ~CPUAllocator() override {} + + string Name() override { return "cpu"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + return port::aligned_malloc(num_bytes, alignment); + } + + void DeallocateRaw(void* ptr) override { port::aligned_free(ptr); } +}; + +Allocator* cpu_allocator() { + static CPUAllocator* cpu_alloc = new CPUAllocator; + return cpu_alloc; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h new file mode 100644 index 0000000000..6f162a608c --- /dev/null +++ b/tensorflow/core/framework/allocator.h @@ -0,0 +1,132 @@ +#ifndef 
TENSORFLOW_FRAMEWORK_ALLOCATOR_H_ +#define TENSORFLOW_FRAMEWORK_ALLOCATOR_H_ + +#include +#include + +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// Allocator is an abstract interface for allocating and deallocating +// device memory. +class Allocator { + public: + virtual ~Allocator(); + + // Return a string identifying this allocator + virtual string Name() = 0; + + // Return an uninitialized block of memory that is "num_bytes" bytes + // in size. The returned pointer is guaranteed to be aligned to a + // multiple of "alignment" bytes. + // REQUIRES: "alignment" is a power of 2. + virtual void* AllocateRaw(size_t alignment, size_t num_bytes) = 0; + + // Deallocate a block of memory pointer to by "ptr" + // REQUIRES: "ptr" was previously returned by a call to AllocateRaw + virtual void DeallocateRaw(void* ptr) = 0; + + // Convenience functions to do typed allocation. Note that these functions + // do not invoke C++ constructors or destructors. May return NULL if the + // tensor has too many elements to represent in a single allocation. + template + T* Allocate(size_t num_elements) { + // TODO(jeff): Do we need to allow clients to pass in alignment + // requirements? + + if (num_elements > (std::numeric_limits::max() / sizeof(T))) { + return NULL; + } + + void* p = AllocateRaw(32 /* align to 32 byte boundary */, + sizeof(T) * num_elements); + return reinterpret_cast(p); + } + + template + void Deallocate(T* ptr) { + DeallocateRaw(ptr); + } + + // Returns true if this allocator tracks the sizes of allocations. + // RequestedSize and AllocatedSize must be overridden if + // TracksAlloctionSizes is overridden to return true. + virtual bool TracksAllocationSizes() { return false; } + + // Returns the user-requested size of the data allocated at + // 'ptr'. Note that the actual buffer allocated might be larger + // than requested, but this function returns the size requested by + // the user. + // + // REQUIRES: TracksAllocationSizes() is true. + // + // REQUIRES: 'ptr!=nullptr' and points to a buffer previously + // allocated by this allocator. + virtual size_t RequestedSize(void* ptr) { + CHECK(false) << "allocator doesn't track sizes"; + } + + // Returns the allocated size of the buffer at 'ptr' if known, + // otherwise returns RequestedSize(ptr). AllocatedSize(ptr) is + // guaranteed to be >= RequestedSize(ptr). + // + // REQUIRES: TracksAllocationSizes() is true. + // + // REQUIRES: 'ptr!=nullptr' and points to a buffer previously + // allocated by this allocator. + virtual size_t AllocatedSize(void* ptr) { return RequestedSize(ptr); } + + // TODO(jeff): Maybe provide some interface to give info about + // current allocation state (total number of bytes available for + // allocation, number of bytes free on device, etc.) +}; + +// A tensorflow Op may need access to different kinds of memory that +// are not simply a function of the device to which the Op has been +// assigned. For example, an Op executing on a GPU may still need +// to allocate CPU RAM for some purpose. Internal to the tensorflow +// runtime we may choose to allocate CPU ram from special regions +// that have been prepared for higher performance in some use +// contexts, e.g. doing DMA with particular devices. For these +// reasons, the Device interface does not expose just one memory +// Allocator, but instead provides an accessor that takes a +// specification of the desired memory attributes in order to select +// an Allocator. 
+// +// NOTE: The upper 8 bits of the value are reserved for +// device-specific uses. Implementors of a device can interpret these +// upper 8 bits in device-specific ways, and ops implemented for those +// devices are responsible for setting those 8 bits appropriately. +// +// Example use: +// // Allocator for ordinary device memory: +// Allocator* a = allocator(AllocatorAttributes()); +// ... +// // Allocator for CPU RAM, regardless of where Op is executing: +// AllocatorAttributes attr; +// attr.set_on_host(true); +// Allocator* a = allocator(attr); +struct AllocatorAttributes { + void set_on_host(bool v) { value |= (static_cast(v)); } + bool on_host() const { return value & 0x1; } + void set_nic_compatible(bool v) { value |= (static_cast(v) << 1); } + bool nic_compatible() const { return value & (0x1 << 1); } + void set_gpu_compatible(bool v) { value |= (static_cast(v) << 2); } + bool gpu_compatible() const { return value & (0x1 << 2); } + + void Merge(AllocatorAttributes other) { value |= other.value; } + + uint32 value = 0; +}; + +// Returns a trivial implementation of Allocator which uses the system +// default malloc. +Allocator* cpu_allocator(); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_ diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc new file mode 100644 index 0000000000..6b1e52cfc4 --- /dev/null +++ b/tensorflow/core/framework/allocator_test.cc @@ -0,0 +1,61 @@ +#include "tensorflow/core/framework/allocator.h" +#include +#include "tensorflow/core/platform/logging.h" +#include +namespace tensorflow { + +TEST(CPUAllocatorTest, Simple) { + Allocator* a = cpu_allocator(); + std::vector ptrs; + for (int s = 1; s < 1024; s++) { + void* raw = a->AllocateRaw(1, s); + ptrs.push_back(raw); + } + std::sort(ptrs.begin(), ptrs.end()); + for (size_t i = 0; i < ptrs.size(); i++) { + if (i > 0) { + CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups + } + a->DeallocateRaw(ptrs[i]); + } + float* t1 = a->Allocate(1024); + double* t2 = a->Allocate(1048576); + a->Deallocate(t1); + a->Deallocate(t2); +} + +// Define a struct that we will use to observe behavior in the unit tests +struct TestStruct { + int x; // not used just want to make sure sizeof(TestStruct) > 1 +}; + +TEST(CPUAllocatorTest, CheckStructSize) { CHECK_GT(sizeof(TestStruct), 1); } + +TEST(CPUAllocatorTest, AllocateOverflowMaxSizeT) { + Allocator* a = cpu_allocator(); + + // The maximum size_t value will definitely overflow. + size_t count_to_allocate = std::numeric_limits::max(); + TestStruct* const test_pointer = a->Allocate(count_to_allocate); + + CHECK_EQ(test_pointer, reinterpret_cast(NULL)); +} + +TEST(CPUAllocatorTest, AllocateOverflowSmallest) { + Allocator* a = cpu_allocator(); + + // count_to_allocate is the smallest count that will cause overflow. 
+ const size_t count_to_allocate = + (std::numeric_limits::max() / sizeof(TestStruct)) + 1; + TestStruct* const test_pointer = a->Allocate(count_to_allocate); + + CHECK_EQ(test_pointer, reinterpret_cast(NULL)); +} + +TEST(CPUAllocatorTest, Sizes) { + Allocator* a = cpu_allocator(); + + EXPECT_EQ(false, a->TracksAllocationSizes()); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/attr_value.proto b/tensorflow/core/framework/attr_value.proto new file mode 100644 index 0000000000..c6a9940815 --- /dev/null +++ b/tensorflow/core/framework/attr_value.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + // TODO(zhifengc/josh11b): implements list(func) if needed. + } + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. 
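+//
+// For illustration (a non-normative text-format sketch), the MatMul[T=float]
+// example above would be written as:
+//   name: "MatMul"
+//   attr { key: "T" value { type: DT_FLOAT } }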
+message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc new file mode 100644 index 0000000000..400ef118b8 --- /dev/null +++ b/tensorflow/core/framework/attr_value_util.cc @@ -0,0 +1,382 @@ +#include "tensorflow/core/framework/attr_value_util.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +namespace { + +string SummarizeString(const string& str) { + return strings::StrCat("\"", str_util::CEscape(str), "\""); +} + +string SummarizeShape(const TensorShapeProto& proto) { + TensorShape shape(proto); + return shape.ShortDebugString(); +} + +string SummarizeTensor(const TensorProto& tensor_proto) { + Tensor t; + if (!t.FromProto(tensor_proto)) { + return strings::StrCat(""); + } + return t.DebugString(); +} + +} // namespace + +string SummarizeAttrValue(const AttrValue& attr_value) { + switch (attr_value.value_case()) { + case AttrValue::kS: + return SummarizeString(attr_value.s()); + case AttrValue::kI: + return strings::StrCat(attr_value.i()); + case AttrValue::kF: + return strings::StrCat(attr_value.f()); + case AttrValue::kB: + return attr_value.b() ? "true" : "false"; + case AttrValue::kType: + return DataType_Name(attr_value.type()); + case AttrValue::kShape: + return SummarizeShape(attr_value.shape()); + case AttrValue::kTensor: + return SummarizeTensor(attr_value.tensor()); + case AttrValue::kList: { + string ret = "["; + if (attr_value.list().s_size() > 0) { + for (int i = 0; i < attr_value.list().s_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i))); + } + } else if (attr_value.list().i_size() > 0) { + for (int i = 0; i < attr_value.list().i_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().i(i)); + } + } else if (attr_value.list().f_size() > 0) { + for (int i = 0; i < attr_value.list().f_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().f(i)); + } + } else if (attr_value.list().b_size() > 0) { + for (int i = 0; i < attr_value.list().b_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().b(i) ? 
"true" : "false"); + } + } else if (attr_value.list().type_size() > 0) { + for (int i = 0; i < attr_value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, DataType_Name(attr_value.list().type(i))); + } + } else if (attr_value.list().shape_size() > 0) { + for (int i = 0; i < attr_value.list().shape_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, SummarizeShape(attr_value.list().shape(i))); + } + } else if (attr_value.list().tensor_size() > 0) { + for (int i = 0; i < attr_value.list().tensor_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, + SummarizeTensor(attr_value.list().tensor(i))); + } + } + strings::StrAppend(&ret, "]"); + return ret; + } + case AttrValue::kFunc: { + std::vector entries; + for (auto p : attr_value.func().attr()) { + entries.push_back( + strings::StrCat(p.first, "=", SummarizeAttrValue(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(attr_value.func().name(), "[", + str_util::Join(entries, ", "), "]"); + } + case AttrValue::kPlaceholder: + return strings::StrCat("$", attr_value.placeholder()); + case AttrValue::VALUE_NOT_SET: + return ""; + } + return ""; // Prevent missing return warning +} + +Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { + int num_set = 0; + +#define VALIDATE_FIELD(name, type_string, oneof_case) \ + do { \ + if (attr_value.has_list()) { \ + if (attr_value.list().name##_size() > 0) { \ + if (type != "list(" type_string ")") { \ + return errors::InvalidArgument( \ + "AttrValue had value with type list(" type_string ") when ", \ + type, " expected"); \ + } \ + ++num_set; \ + } \ + } else if (attr_value.value_case() == AttrValue::oneof_case) { \ + if (type != type_string) { \ + return errors::InvalidArgument( \ + "AttrValue had value with type " type_string " when ", type, \ + " expected"); \ + } \ + ++num_set; \ + } \ + } while (false) + + VALIDATE_FIELD(s, "string", kS); + VALIDATE_FIELD(i, "int", kI); + VALIDATE_FIELD(f, "float", kF); + VALIDATE_FIELD(b, "bool", kB); + VALIDATE_FIELD(type, "type", kType); + VALIDATE_FIELD(shape, "shape", kShape); + VALIDATE_FIELD(tensor, "tensor", kTensor); + +#undef VALIDATE_FIELD + + if (attr_value.value_case() == AttrValue::kFunc) { + if (type != "func") { + return errors::InvalidArgument( + "AttrValue had value with type 'func' when ", type, " expected"); + } + ++num_set; + } + + if (attr_value.value_case() == AttrValue::kPlaceholder) { + return errors::InvalidArgument( + "AttrValue had value with unexpected type 'placeholder"); + } + + // If the attr type is 'list', we expect attr_value.has_list() to be true. + // However, proto3's attr_value.has_list() can be false when set to an empty + // list. So we simply check if has_list is false and some other field in + // attr_value is set to flag the error. + if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) { + if (num_set) { + return errors::InvalidArgument( + "AttrValue missing value with expected type ", type); + } else { + // Indicate that we have a list, but an empty one. + ++num_set; + } + } + + // Okay to have an empty list, but not to be missing a non-list value. + if (num_set == 0 && !StringPiece(type).starts_with("list(")) { + return errors::InvalidArgument( + "AttrValue missing value with expected type ", type); + } + + // Ref types and DT_INVALID are illegal. 
+ if (type == "type") { + if (IsRefType(attr_value.type())) { + return errors::InvalidArgument( + "AttrValue must not have reference type value of ", + DataTypeString(attr_value.type())); + } + if (attr_value.type() == DT_INVALID) { + return errors::InvalidArgument("AttrValue has invalid DataType"); + } + } else if (type == "list(type)") { + for (auto as_int : attr_value.list().type()) { + const DataType dtype = static_cast(as_int); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "AttrValue must not have reference type value of ", + DataTypeString(dtype)); + } + if (dtype == DT_INVALID) { + return errors::InvalidArgument("AttrValue contains invalid DataType"); + } + } + } + + return Status::OK(); +} + +bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { + // Parse type. + string field_name; + bool is_list = type.Consume("list("); + if (type.Consume("string")) { + field_name = "s"; + } else if (type.Consume("int")) { + field_name = "i"; + } else if (type.Consume("float")) { + field_name = "f"; + } else if (type.Consume("bool")) { + field_name = "b"; + } else if (type.Consume("type")) { + field_name = "type"; + } else if (type.Consume("shape")) { + field_name = "shape"; + } else if (type.Consume("tensor")) { + field_name = "tensor"; + } else if (type.Consume("func")) { + field_name = "func"; + } else if (type.Consume("placeholder")) { + field_name = "placeholder"; + } else { + return false; + } + if (is_list && !type.Consume(")")) { + return false; + } + + // Construct a valid text proto message to parse. + string to_parse; + if (is_list) { + // TextFormat parser considers "i: 7" to be the same as "i: [7]", + // but we only want to allow list values with []. + if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) { + return false; + } + if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) { + // User wrote "[]", so return empty list without invoking the TextFormat + // parse which returns an error for "i: []". + out->Clear(); + out->mutable_list(); + return true; + } + to_parse = strings::StrCat("list { ", field_name, ": ", text, " }"); + } else { + to_parse = strings::StrCat(field_name, ": ", text); + } + + // Parse if we can. 
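+  // For example (illustrative): ParseAttrValue("list(int)", "[1, 2]", out)
+  // reaches this point with to_parse == "list { i: [1, 2] }".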
+ return protobuf::TextFormat::ParseFromString(to_parse, out); +} + +#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ + void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); } + +#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \ + void SetAttrValue(ARG_TYPE value, AttrValue* out) { \ + out->mutable_list(); /* create list() even if value empty */ \ + for (const auto& v : value) { \ + out->mutable_list()->add_##FIELD(v); \ + } \ + } + +#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \ + DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ + DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, FIELD) + +DEFINE_SET_ATTR_VALUE_ONE(const string&, s) +DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, s) +DEFINE_SET_ATTR_VALUE_BOTH(const char*, s) +DEFINE_SET_ATTR_VALUE_BOTH(int64, i) +DEFINE_SET_ATTR_VALUE_BOTH(int32, i) +DEFINE_SET_ATTR_VALUE_BOTH(float, f) +DEFINE_SET_ATTR_VALUE_BOTH(double, f) +DEFINE_SET_ATTR_VALUE_BOTH(bool, b) +DEFINE_SET_ATTR_VALUE_LIST(const std::vector&, b) +DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list, b) +DEFINE_SET_ATTR_VALUE_BOTH(DataType, type) + +void SetAttrValue(StringPiece value, AttrValue* out) { + out->set_s(value.data(), value.size()); +} + +void SetAttrValue(const TensorShape& value, AttrValue* out) { + value.AsProto(out->mutable_shape()); +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + v.AsProto(out->mutable_list()->add_shape()); + } +} + +void SetAttrValue(const Tensor& value, AttrValue* out) { + if (value.NumElements() > 1) { + value.AsProtoTensorContent(out->mutable_tensor()); + } else { + value.AsProtoField(out->mutable_tensor()); + } +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + if (v.NumElements() > 1) { + v.AsProtoTensorContent(out->mutable_list()->add_tensor()); + } else { + v.AsProtoField(out->mutable_list()->add_tensor()); + } + } +} + +void SetAttrValue(const TensorProto& value, AttrValue* out) { + *out->mutable_tensor() = value; +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + *out->mutable_list()->add_tensor() = v; + } +} + +void SetAttrValue(const NameAttrList& value, AttrValue* out) { + *out->mutable_func() = value; +} + +bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { + string a_str, b_str; + a.SerializeToString(&a_str); + b.SerializeToString(&b_str); + // Note: it should be safe to compare proto serializations of the attr + // values since at most one field should be set in each (indeed, it + // must be the same field if they are to compare equal). + // Exception: there are multiple equivalent representations of + // TensorProtos. So a return value of true implies a == b, but not the + // converse. 
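+  // For example, the same tensor serialized via tensor_content versus via
+  // the repeated typed value fields of TensorProto produces different bytes
+  // here, so two equivalent tensor attrs may compare as unequal.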
+ return a_str == b_str; +} + +bool HasPlaceHolder(const AttrValue& val) { + switch (val.value_case()) { + case AttrValue::kFunc: + for (const auto& p : val.func().attr()) { + if (HasPlaceHolder(p.second)) { + return true; + } + } + break; + case AttrValue::kPlaceholder: + return true; + default: + break; + } + return false; +} + +bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value) { + switch (value->value_case()) { + case AttrValue::kFunc: + for (auto& p : *(value->mutable_func()->mutable_attr())) { + if (!SubstitutePlaceholders(substitute, &p.second)) { + return false; + } + } + break; + case AttrValue::kPlaceholder: + return substitute(value->placeholder(), value); + case AttrValue::VALUE_NOT_SET: + return false; + default: + break; + } + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h new file mode 100644 index 0000000000..1faf74a327 --- /dev/null +++ b/tensorflow/core/framework/attr_value_util.h @@ -0,0 +1,83 @@ +#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_ + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// A human-readable rendering of attr_value, that is more concise than a +// text-format proto. +string SummarizeAttrValue(const AttrValue& attr_value); + +// Generates an error if attr_value doesn't have the indicated attr type. +Status AttrValueHasType(const AttrValue& attr_value, StringPiece type); + +// Converts a text proto value from "text" into the the field of *out +// indicated by "type" (e.g. from the type field of an AttrDef). +// Examples: +// * If type:"int" and text:"-14", then *out is set to "i: -14" +// * If type:"list(string)" and text:"['foo', 'bar']", +// then *out is set to "list { s: ['foo', 'bar'] }" +// Returns true on success. +bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out); + +// Sets *out based on the type of value. 
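+//
+// For example (an illustrative sketch only):
+//   AttrValue v;
+//   SetAttrValue(42, &v);        // v now holds i: 42.
+//   SetAttrValue(DT_FLOAT, &v);  // Resets the oneof; v now holds type: DT_FLOAT.
+//   TF_CHECK_OK(AttrValueHasType(v, "type"));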
+void SetAttrValue(const string& value, AttrValue* out);
+void SetAttrValue(const char* value, AttrValue* out);
+void SetAttrValue(StringPiece value, AttrValue* out);
+void SetAttrValue(int64 value, AttrValue* out);
+void SetAttrValue(int32 value, AttrValue* out);
+void SetAttrValue(float value, AttrValue* out);
+void SetAttrValue(double value, AttrValue* out);
+void SetAttrValue(bool value, AttrValue* out);
+void SetAttrValue(DataType value, AttrValue* out);
+void SetAttrValue(const TensorShape& value, AttrValue* out);
+void SetAttrValue(const Tensor& value, AttrValue* out);
+void SetAttrValue(const TensorProto& value, AttrValue* out);
+void SetAttrValue(const NameAttrList& value, AttrValue* out);
+
+void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out);
+void SetAttrValue(const std::vector<bool>& value, AttrValue* out);
+void SetAttrValue(std::initializer_list<bool> value, AttrValue* out);
+void SetAttrValue(DataTypeSlice value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
+
+inline void SetAttrValue(const AttrValue& value, AttrValue* out) {
+  *out = value;
+}
+
+// Returns true if a and b have the same value.
+// NOTE: May return false negatives for tensor values.
+bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b);
+
+// Returns true if "val" has a placeholder.
+bool HasPlaceHolder(const AttrValue& val);
+
+// SubstitutePlaceholders recursively replaces placeholders in 'value'
+// with an attr value by calling SubstituteFunc. Returns true iff all
+// placeholders in "value" are replaced with a value.
+//
+// SubstituteFunc is given a placeholder string. If the placeholder is
+// unknown, SubstituteFunc returns false. Otherwise, overwrites the
+// attr value and returns true.
+typedef std::function<bool(const string&, AttrValue*)> SubstituteFunc;
+bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
new file mode 100644
index 0000000000..bdfbf1707a
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -0,0 +1,91 @@
+#include "tensorflow/core/framework/attr_value_util.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+// A few helpers to construct AttrValue protos.
+template <typename T>
+AttrValue V(T value) {
+  AttrValue ret;
+  SetAttrValue(value, &ret);
+  return ret;
+}
+
+AttrValue P(const string& p) {
+  AttrValue ret;
+  ret.set_placeholder(p);
+  return ret;
+}
+
+AttrValue F(const string& name,
+            std::vector<std::pair<string, AttrValue>> pairs) {
+  AttrValue ret;
+  ret.mutable_func()->set_name(name);
+  ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end());
+  return ret;
+}
+
+TEST(AttrValueUtil, HasType) {
+  // OK
+  EXPECT_TRUE(AttrValueHasType(V(123), "int").ok());
+  EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok());
+  EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok());
+  EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok());
+
+  // not OK.
+ EXPECT_FALSE(AttrValueHasType(V(123), "func").ok()); + EXPECT_FALSE(AttrValueHasType(V(1.2), "int").ok()); + EXPECT_FALSE(AttrValueHasType(V(DT_FLOAT), "shape").ok()); + EXPECT_FALSE(AttrValueHasType(F("f", {}), "string").ok()); + EXPECT_FALSE(AttrValueHasType(P("T"), "float").ok()); +} + +SubstituteFunc ReplaceTWith(const AttrValue& val) { + return [val](const string& placeholder, AttrValue* target) { + if (placeholder == "T") { + *target = val; + return true; + } else { + return false; + } + }; +} + +TEST(AttrValueUtil, Basic) { + auto v = F("MatMul", {{"dtype", P("T")}, + {"transpose_a", V(false)}, + {"transpose_b", V(true)}, + {"use_cublas", V(true)}}); + TF_CHECK_OK(AttrValueHasType(v, "func")); + EXPECT_TRUE(HasPlaceHolder(v)); + + EXPECT_EQ( + SummarizeAttrValue(v), + "MatMul[dtype=$T, transpose_a=false, transpose_b=true, use_cublas=true]"); + + SubstitutePlaceholders(ReplaceTWith(V(DT_FLOAT)), &v); + EXPECT_TRUE(!HasPlaceHolder(v)); + EXPECT_EQ(SummarizeAttrValue(v), + "MatMul[dtype=DT_FLOAT, transpose_a=false, transpose_b=true, " + "use_cublas=true]"); +} + +TEST(AttrValueUtil, DeepAttr) { + auto v = F("f", {{"T", P("T")}}); + TF_CHECK_OK(AttrValueHasType(v, "func")); + EXPECT_TRUE(HasPlaceHolder(v)); + + for (int i = 0; i < 3; ++i) { + v = F("f", {{"T", P("T")}, {"F", v}}); + EXPECT_TRUE(HasPlaceHolder(v)); + } + EXPECT_EQ(SummarizeAttrValue(v), "f[F=f[F=f[F=f[T=$T], T=$T], T=$T], T=$T]"); + + SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v); + EXPECT_TRUE(!HasPlaceHolder(v)); + EXPECT_EQ(SummarizeAttrValue(v), + "f[F=f[F=f[F=f[T=x[]], T=x[]], T=x[]], T=x[]]"); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc new file mode 100644 index 0000000000..0068283367 --- /dev/null +++ b/tensorflow/core/framework/bfloat16.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/framework/bfloat16.h" + +namespace tensorflow { + +void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) { + const uint16_t* p = reinterpret_cast(src); + uint16_t* q = reinterpret_cast(dst); + for (; size; p += 2, q++, size--) { + *q = p[1]; + } +} + +void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) { + const uint16_t* p = reinterpret_cast(src); + uint16_t* q = reinterpret_cast(dst); + for (; size; p++, q += 2, size--) { + q[0] = 0; + q[1] = *p; + } +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/bfloat16.h b/tensorflow/core/framework/bfloat16.h new file mode 100644 index 0000000000..9cd260ee13 --- /dev/null +++ b/tensorflow/core/framework/bfloat16.h @@ -0,0 +1,58 @@ +#ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_ +#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_ + +#include "tensorflow/core/platform/port.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +// Compact 16-bit encoding of floating point numbers. This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. It +// is assumed that floats are in IEEE 754 format so the representation is just +// bits 16-31 of a single precision float. +// +// NOTE: The IEEE floating point standard defines a float16 format that +// is different than this format (it has fewer bits of exponent and more +// bits of mantissa). We don't use that format here because conversion +// to/from 32-bit floats is more complex for that format, and the +// conversion for this format is very simple. +// +// Because of the existing IEEE float16 type, we do not name our representation +// "float16" but just use "uint16". 
+// +// <-----our 16bits float-------> +// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f +// <------------------------------float--------------------------> +// 3 3 2 2 1 1 0 +// 1 0 3 2 5 4 0 +// +// +// This type only supports conversion back and forth with float. +// +// This file must be compilable by nvcc. + +namespace tensorflow { +struct bfloat16 { + EIGEN_DEVICE_FUNC bfloat16() {} + EIGEN_DEVICE_FUNC explicit bfloat16(const uint16_t v) : value(v) {} + + uint16_t value; +}; + +// Conversion routines between an array of float and bfloat16 of +// "size". +void FloatToBFloat16(const float* src, bfloat16* dst, int64 size); +void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size); + +} // namespace tensorflow + +namespace Eigen { +template <> +struct NumTraits : GenericNumTraits {}; + +EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a, + const tensorflow::bfloat16 b) { + return a.value == b.value; +} + +} // namespace Eigen + +#endif // TENSORFLOW_FRAMEWORK_BFLOAT16_H_ diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc new file mode 100644 index 0000000000..4fe791fdeb --- /dev/null +++ b/tensorflow/core/framework/bfloat16_test.cc @@ -0,0 +1,69 @@ +#include "tensorflow/core/framework/bfloat16.h" + +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { +namespace { + +TEST(Bfloat16Test, Simple) { + bfloat16 a(12); + EXPECT_EQ(12, a.value); +} + +TEST(Bfloat16Test, Conversion) { + float a[100]; + for (int i = 0; i < 100; ++i) { + a[i] = i + 1.25; + } + bfloat16 b[100]; + float c[100]; + FloatToBFloat16(a, b, 100); + BFloat16ToFloat(b, c, 100); + for (int i = 0; i < 100; ++i) { + // The relative error should be less than 1/(2^7) since bfloat16 + // has 7 bits mantissa. 
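+    // Worked example (illustrative): 100.25f has bit pattern 0x42C88000;
+    // truncating to the high 16 bits gives 0x42C8, which decodes back to
+    // 100.0f, a relative error of 0.25 / 100.25 (about 0.0025) < 1/128.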
+ EXPECT_LE(fabs(c[i] - a[i]) / a[i], 1.0 / 128); + } +} + +static void BM_FloatToBFloat16(int iters) { + testing::StopTiming(); + static const int N = 32 << 20; + const int64 tot = static_cast(iters) * N; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + + float* inp = new float[N]; + bfloat16* out = new bfloat16[N]; + + testing::StartTiming(); + while (iters--) { + FloatToBFloat16(inp, out, N); + } + delete[] inp; + delete[] out; +} +BENCHMARK(BM_FloatToBFloat16); + +static void BM_BFloat16ToFloat(int iters) { + testing::StopTiming(); + static const int N = 32 << 20; + const int64 tot = static_cast(iters) * N; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + + bfloat16* inp = new bfloat16[N]; + float* out = new float[N]; + + testing::StartTiming(); + while (iters--) { + BFloat16ToFloat(inp, out, N); + } + delete[] inp; + delete[] out; +} +BENCHMARK(BM_BFloat16ToFloat); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc new file mode 100644 index 0000000000..51423792a8 --- /dev/null +++ b/tensorflow/core/framework/cancellation.cc @@ -0,0 +1,79 @@ +#include "tensorflow/core/framework/cancellation.h" + +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +const CancellationToken CancellationManager::kInvalidToken = -1; + +CancellationManager::CancellationManager() + : is_cancelling_(false), is_cancelled_(0), next_cancellation_token_(0) {} + +void CancellationManager::StartCancel() { + std::unordered_map callbacks_to_run; + { + mutex_lock l(mu_); + if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) { + return; + } + is_cancelling_ = true; + std::swap(callbacks_, callbacks_to_run); + } + // We call these callbacks without holding mu_, so that concurrent + // calls to DeregisterCallback, which can happen asynchronously, do + // not block. The callbacks remain valid because any concurrent call + // to DeregisterCallback will block until the + // cancelled_notification_ is notified. + for (auto key_and_value : callbacks_to_run) { + key_and_value.second(); + } + { + mutex_lock l(mu_); + is_cancelling_ = false; + is_cancelled_.store(true, std::memory_order_release); + } + cancelled_notification_.Notify(); +} + +CancellationToken CancellationManager::get_cancellation_token() { + mutex_lock l(mu_); + return next_cancellation_token_++; +} + +bool CancellationManager::RegisterCallback(CancellationToken token, + CancelCallback callback) { + mutex_lock l(mu_); + CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token"; + bool should_register = !is_cancelled_ && !is_cancelling_; + if (should_register) { + std::swap(callbacks_[token], callback); + } + return should_register; +} + +bool CancellationManager::DeregisterCallback(CancellationToken token) { + mu_.lock(); + if (is_cancelled_) { + mu_.unlock(); + return false; + } else if (is_cancelling_) { + mu_.unlock(); + // Wait for all of the cancellation callbacks to be called. This + // wait ensures that the caller of DeregisterCallback does not + // return immediately and free objects that may be used in the + // execution of any currently pending callbacks in StartCancel. 
+ cancelled_notification_.WaitForNotification(); + return false; + } else { + callbacks_.erase(token); + mu_.unlock(); + return true; + } +} + +CancellationManager::~CancellationManager() { StartCancel(); } + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h new file mode 100644 index 0000000000..feda548e97 --- /dev/null +++ b/tensorflow/core/framework/cancellation.h @@ -0,0 +1,121 @@ +#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_ +#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// A token that can be used to register and deregister a +// CancelCallback with a CancellationManager. +// +// CancellationToken values must be created by a call to +// CancellationManager::get_cancellation_token. +typedef int64 CancellationToken; + +// A callback that is invoked when a step is cancelled. +// +// NOTE(mrry): See caveats about CancelCallback implementations in the +// comment for CancellationManager::RegisterCallback. +typedef std::function CancelCallback; + +class CancellationManager { + public: + // A value that won't be returned by get_cancellation_token(). + static const CancellationToken kInvalidToken; + + CancellationManager(); + ~CancellationManager(); + + // Run all callbacks associated with this manager. + void StartCancel(); + + // Returns true iff StartCancel() has been called. + bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); } + + // Returns a token that must be used in calls to RegisterCallback + // and DeregisterCallback. + CancellationToken get_cancellation_token(); + + // Attempts to register the given callback to be invoked when this + // manager is cancelled. Returns true if the callback was + // registered; returns false if this manager was already cancelled, + // and the callback was not registered. + // + // If this method returns false, it is the caller's responsibility + // to perform any cancellation cleanup. + // + // This method is tricky to use correctly. The following usage pattern + // is recommended: + // + // class ObjectWithCancellableOperation { + // mutex mu_; + // void CancellableOperation(CancellationManager* cm, + // std::function callback) { + // bool already_cancelled; + // CancellationToken token = cm->get_cancellation_token(); + // { + // mutex_lock(mu_); + // already_cancelled = cm->RegisterCallback( + // [this, token]() { Cancel(token); }); + // if (!already_cancelled) { + // // Issue asynchronous operation. Associate the pending operation + // // with `token` in some object state, or provide another way for + // // the Cancel method to look up the operation for cancellation. + // // Ensure that `cm->DeregisterCallback(token)` is called without + // // holding `mu_`, before `callback` is invoked. + // // ... + // } + // } + // if (already_cancelled) { + // callback(errors::Cancelled("Operation was cancelled")); + // } + // } + // + // void Cancel(CancellationToken token) { + // mutex_lock(mu_); + // // Take action to cancel the operation with the given cancellation + // // token. + // } + // + // NOTE(mrry): The caller should take care that (i) the calling code + // is robust to `callback` being invoked asynchronously (e.g. 
from + // another thread), (ii) `callback` is deregistered by a call to + // this->DeregisterCallback(token) when the operation completes + // successfully, and (iii) `callback` does not invoke any method + // on this cancellation manager. Furthermore, it is important that + // the eventual caller of the complementary DeregisterCallback does not + // hold any mutexes that are required by `callback`. + bool RegisterCallback(CancellationToken token, CancelCallback callback); + + // Deregister the callback that, when registered, was associated + // with the given cancellation token. Returns true iff the callback + // was deregistered and will not be invoked; otherwise returns false + // after the callback has been invoked, blocking if necessary. + // + // NOTE(mrry): This method may block if cancellation is in progress. + // The caller of this method must not hold any mutexes that are required + // to invoke any cancellation callback that has been registered with this + // cancellation manager. + bool DeregisterCallback(CancellationToken token); + + private: + bool is_cancelling_; + std::atomic_bool is_cancelled_; + + mutex mu_; + Notification cancelled_notification_; + CancellationToken next_cancellation_token_ GUARDED_BY(mu_); + std::unordered_map callbacks_ + GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_ diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc new file mode 100644 index 0000000000..1925dd20cc --- /dev/null +++ b/tensorflow/core/framework/cancellation_test.cc @@ -0,0 +1,102 @@ +#include "tensorflow/core/framework/cancellation.h" + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include + +namespace tensorflow { + +TEST(Cancellation, SimpleNoCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + bool deregistered = manager->DeregisterCallback(token); + EXPECT_TRUE(deregistered); + delete manager; + EXPECT_FALSE(is_cancelled); +} + +TEST(Cancellation, SimpleCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled); + delete manager; +} + +TEST(Cancellation, CancelBeforeRegister) { + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + manager->StartCancel(); + bool registered = manager->RegisterCallback(token, nullptr); + EXPECT_FALSE(registered); + delete manager; +} + +TEST(Cancellation, DeregisterAfterCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled); + bool deregistered = manager->DeregisterCallback(token); + EXPECT_FALSE(deregistered); + delete manager; +} + +TEST(Cancellation, CancelMultiple) { + bool is_cancelled_1 = false, is_cancelled_2 = false, is_cancelled_3 = false; + 
CancellationManager* manager = new CancellationManager(); + auto token_1 = manager->get_cancellation_token(); + bool registered_1 = manager->RegisterCallback( + token_1, [&is_cancelled_1]() { is_cancelled_1 = true; }); + EXPECT_TRUE(registered_1); + auto token_2 = manager->get_cancellation_token(); + bool registered_2 = manager->RegisterCallback( + token_2, [&is_cancelled_2]() { is_cancelled_2 = true; }); + EXPECT_TRUE(registered_2); + EXPECT_FALSE(is_cancelled_1); + EXPECT_FALSE(is_cancelled_2); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled_1); + EXPECT_TRUE(is_cancelled_2); + EXPECT_FALSE(is_cancelled_3); + auto token_3 = manager->get_cancellation_token(); + bool registered_3 = manager->RegisterCallback( + token_3, [&is_cancelled_3]() { is_cancelled_3 = true; }); + EXPECT_FALSE(registered_3); + EXPECT_FALSE(is_cancelled_3); + delete manager; +} + +TEST(Cancellation, IsCancelled) { + CancellationManager* cm = new CancellationManager(); + thread::ThreadPool w(Env::Default(), "test", 4); + std::vector done(8); + for (size_t i = 0; i < done.size(); ++i) { + Notification* n = &done[i]; + w.Schedule([n, cm]() { + while (!cm->IsCancelled()) { + } + n->Notify(); + }); + } + Env::Default()->SleepForMicroseconds(1000000 /* 1 second */); + cm->StartCancel(); + for (size_t i = 0; i < done.size(); ++i) { + done[i].WaitForNotification(); + } + delete cm; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto new file mode 100644 index 0000000000..f0def3d6d7 --- /dev/null +++ b/tensorflow/core/framework/config.proto @@ -0,0 +1,61 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +message GPUOptions { + // A value between 0 and 1 that indicates what fraction of the + // available GPU memory to pre-allocate for each process. 1 means + // to pre-allocate all of the GPU memory, 0.5 means the process + // allocates ~50% of the available GPU memory. + double per_process_gpu_memory_fraction = 1; +}; + +// Session configuration parameters. +// The system picks an appropriate values for fields that are not set. +message ConfigProto { + // Map from device type name (e.g., "CPU" or "GPU" ) to maximum + // number of devices of that type to use. If a particular device + // type is not found in the map, the system picks an appropriate + // number. + map device_count = 1; + + // The execution of an individual op (for some op types) can be + // parallelized on a pool of intra_op_parallelism_threads. + // 0 means the system picks an appropriate number. + int32 intra_op_parallelism_threads = 2; + + // Nodes that perform blocking operations are enqueued on a pool of + // inter_op_parallelism_threads available in each process. + // + // 0 means the system picks an appropriate number. + // + // Note that the first Session created in the process sets the + // number of threads for all future sessions. + int32 inter_op_parallelism_threads = 5; + + // Assignment of Nodes to Devices is recomputed every placement_period + // steps until the system warms up (at which point the recomputation + // typically slows down automatically). + int32 placement_period = 3; + + // When any filters are present sessions will ignore all devices which do not + // match the filters. Each filter can be partially specified, e.g. "/job:ps" + // "/job:worker/replica:3", etc. + repeated string device_filters = 4; + + // Options that apply to all GPUs. + GPUOptions gpu_options = 6; + + // Whether soft placement is allowed. 
If allow_soft_placement is true, + // an op will be placed on CPU if + // 1. there's no GPU implementation for the OP + // or + // 2. no GPU devices are known or registered + // or + // 3. need to co-locate with reftype input(s) which are from CPU. + bool allow_soft_placement = 7; + + // Whether device placements should be logged. + bool log_device_placement = 8; +}; diff --git a/tensorflow/core/framework/control_flow.h b/tensorflow/core/framework/control_flow.h new file mode 100644 index 0000000000..f59e0f5310 --- /dev/null +++ b/tensorflow/core/framework/control_flow.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ +#define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { + +const uint64 kIllegalFrameId = ~0uLL; +const int64 kIllegalIterId = -1; + +// For the purpose of control flow, every tensor produced by TensorFlow is +// conceptually tagged by a 'FrameAndIter'. FrameAndIter consists of a +// 'frame_id' and an 'iter_id'. The tensor value it represents is produced +// in the frame with frame_id at the iteration of iter_id. +struct FrameAndIter { + uint64 frame_id = kIllegalFrameId; + int64 iter_id = kIllegalIterId; + + FrameAndIter() {} + + FrameAndIter(uint64 frame, int64 iter) { + frame_id = frame; + iter_id = iter; + } + + bool operator==(const FrameAndIter& other) const { + return (frame_id == other.frame_id && iter_id == other.iter_id); + } +}; + +struct FrameAndIterHash { + size_t operator()(const FrameAndIter& key) const { + // Make sure there are no padding bytes that we don't want + CHECK_EQ(sizeof(uint64) + sizeof(int64), sizeof(FrameAndIter)); + return Hash64(reinterpret_cast(&key), sizeof(FrameAndIter)); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ diff --git a/tensorflow/core/framework/device_attributes.proto b/tensorflow/core/framework/device_attributes.proto new file mode 100644 index 0000000000..7592215d1e --- /dev/null +++ b/tensorflow/core/framework/device_attributes.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +// BusAdjacency identifies the ability of a device to participate in +// maximally efficient DMA operations within the local context of a +// process. +// +// This is currently ignored. +enum BusAdjacency { + BUS_0 = 0; + BUS_1 = 1; + BUS_ANY = 2; + BUS_NUM_ADJACENCIES = 3; +}; + +message DeviceAttributes { + string name = 1; + + // String representation of device_type. + string device_type = 2; + + // Memory capacity of device in bytes. + int64 memory_limit = 4; + + BusAdjacency bus_adjacency = 5; + + // A device is assigned a global unique number each time it is + // initialized. "incarnation" should never be 0. + fixed64 incarnation = 6; + + // String representation of the physical device that this device maps to. 
+ string physical_device_desc = 7; +} diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc new file mode 100644 index 0000000000..83ad199062 --- /dev/null +++ b/tensorflow/core/framework/device_base.cc @@ -0,0 +1,7 @@ +#include "tensorflow/core/framework/device_base.h" + +namespace tensorflow { + +DeviceBase::~DeviceBase() {} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h new file mode 100644 index 0000000000..ed4ffc5d94 --- /dev/null +++ b/tensorflow/core/framework/device_base.h @@ -0,0 +1,172 @@ +#ifndef TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ +#define TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ + +#include +#include + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/public/status.h" + +namespace Eigen { +class ThreadPoolDevice; +} // end namespace Eigen + +namespace perftools { +namespace gputools { +class Stream; +} // namespace gputools +} // namespace perftools + +namespace tensorflow { + +class Device; +class Env; +class EventMgr; + +namespace thread { +class ThreadPool; +} + +// A wrapper for an Eigen Gpu Device that includes per-op state +class PerOpGpuDevice { + public: + virtual ~PerOpGpuDevice() {} + virtual const Eigen::GpuDevice& device() const = 0; +}; + +// A class that devices can subclass to pass around +// Device-specific context to OpKernels. +class DeviceContext : public core::RefCounted { + public: + ~DeviceContext() override {} + virtual perftools::gputools::Stream* stream() const { return nullptr; } + virtual void MaintainLifetimeOnStream( + const Tensor* t, perftools::gputools::Stream* stream) const {} + + // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into + // "device_tensor" which is on a GPU device "device". "device_tensor" + // must be allocated to be of the same size as "cpu_tensor". + virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, + StatusCallback done) const { + done(errors::Internal("Unrecognized device type in CPU-to-device Copy")); + } + + // "device_tensor" is a tensor on a non-CPU device. Copies + // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated + // to be of the same size as "device_tensor". + virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor, + const string& tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) { + done(errors::Internal("Unrecognized device type in device-to-CPU Copy")); + } +}; + +typedef std::unordered_map DeviceContextMap; + +class DeviceBase { + public: + explicit DeviceBase(Env* env) : env_(env) {} + virtual ~DeviceBase(); + + Env* env() const { return env_; } + + // Override this to return true for devices that require an Op's + // compute method to save references to the temporary tensors it + // allocates until the Op execution completes + virtual bool SaveTemporaryTensors() const { return false; } + + struct CpuWorkerThreads { + int num_threads = 0; + thread::ThreadPool* workers = nullptr; + }; + + // Does not take ownership. 
+ void set_tensorflow_cpu_worker_threads(CpuWorkerThreads* t) { + cpu_worker_threads_ = t; + } + + const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { + CHECK(cpu_worker_threads_ != nullptr); + return cpu_worker_threads_; + } + + // "stream" is used in special circumstances (such as the + // constructors of Ops) where there is no available OpKernelContext. + // "default_context" is used by OpKernelContext whenever a device does not + // supply a DeviceContext for an op in FillContextMap (e.g. when only + // using a single stream.) + // "event_mgr" is used to delay deallocation of temporary GPU buffers. + // TODO(pbar) Work out how to move this out of DeviceBase. + struct GpuDeviceInfo { + perftools::gputools::Stream* stream; + DeviceContext* default_context; + EventMgr* event_mgr; + }; + + // Does not take ownership. + void set_tensorflow_gpu_device_info(GpuDeviceInfo* g) { + gpu_device_info_ = g; + } + + const GpuDeviceInfo* tensorflow_gpu_device_info() const { + return gpu_device_info_; + } + + // Does not take ownership. + void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) { + eigen_cpu_device_ = d; + } + + // Return the Allocator implementation to use based on the allocator + // attributes requested. See allocator.h for more details. + virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) { + LOG(FATAL) << "GetAllocator() is not implemented."; + } + + const Eigen::ThreadPoolDevice* eigen_cpu_device() { + CHECK(eigen_cpu_device_ != nullptr); + return eigen_cpu_device_; + } + + // The caller owns the returned device and must free it by calling + // DisposeGpuDevice below + virtual const PerOpGpuDevice* MakeGpuDevice(DeviceContext* /*dc*/, + Allocator* /*allocator*/) { + // The OpKernelContext calls this even for devices that do not + // implement an eigen_gpu_device + return nullptr; + } + + virtual const DeviceAttributes& attributes() const { + LOG(FATAL) << "Device does not implement attributes()"; + } + + // Materializes the given TensorProto into 'tensor' stored in Device + // memory. Most devices will want to override this. + // + // TODO(vrv): We should be able to put this function into + // OpKernelContext and handle the copies from device memory via send + // and receive nodes, instead of requiring that each device handle + // the copies here as well as in copy ops. 
+ virtual Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) { + return errors::Internal("Device does not implement MakeTensorFromProto()"); + } + + private: + Env* const env_; + CpuWorkerThreads* cpu_worker_threads_ = nullptr; + GpuDeviceInfo* gpu_device_info_ = nullptr; + Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc new file mode 100644 index 0000000000..493c35e05f --- /dev/null +++ b/tensorflow/core/framework/fake_input.cc @@ -0,0 +1,214 @@ +#include "tensorflow/core/framework/fake_input.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace { + +class FakeInputImpl { + public: + FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def, + NodeDefBuilder* builder); + void SetN(int n); + void SetDataType(DataType dt); + void SetTypeList(DataTypeSlice dts); + Status AddInputToBuilder(); + + private: + static string FakeNodeName(int in_index); + Status GetN(int* n) const; + Status GetDataType(DataType* dt) const; + void NSources(int n, DataType dt) const; + void SourceList(DataTypeSlice dts) const; + + const OpDef* const op_def_; + const OpDef::ArgDef* const arg_; + const string in_node_; + const NodeDef* const node_def_; + NodeDefBuilder* const builder_; + + bool n_specified_; + int n_; + bool dt_specified_; + DataType dt_; + bool dts_specified_; + DataTypeSlice dts_; +}; + +FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index, + const NodeDef* node_def, NodeDefBuilder* builder) + : op_def_(op_def), + arg_(&op_def->input_arg(in_index)), + in_node_(FakeNodeName(in_index)), + node_def_(node_def), + builder_(builder), + n_specified_(false), + dt_specified_(false), + dts_specified_(false) {} + +void FakeInputImpl::SetN(int n) { + n_specified_ = true; + n_ = n; +} + +void FakeInputImpl::SetDataType(DataType dt) { + dt_specified_ = true; + dt_ = dt; +} + +void FakeInputImpl::SetTypeList(DataTypeSlice dts) { + dts_specified_ = true; + dts_ = dts; +} + +Status FakeInputImpl::AddInputToBuilder() { + if (dts_specified_) { + SourceList(dts_); + + } else if (n_specified_ || !arg_->number_attr().empty()) { + int n; + TF_RETURN_IF_ERROR(GetN(&n)); + + DataType dt; + if (n > 0) { + TF_RETURN_IF_ERROR(GetDataType(&dt)); + } else { + dt = DT_FLOAT; + } + + NSources(n, dt); + } else { + if (!dt_specified_ && !arg_->type_list_attr().empty()) { + DataTypeVector dts; + Status status = + GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts); + if (!status.ok()) { + return errors::InvalidArgument( + "Could not infer list of types for input '", arg_->name(), "': ", + status.error_message()); + } + SourceList(dts); + return Status::OK(); + } + + DataType dt; + TF_RETURN_IF_ERROR(GetDataType(&dt)); + builder_->Input(in_node_, 0, dt); + } + return Status::OK(); +} + +// static +string FakeInputImpl::FakeNodeName(int in_index) { + char c = 'a' + (in_index % 26); + return string(&c, 1); +} + +Status FakeInputImpl::GetN(int* n) const { + if (n_specified_) { + *n = n_; + } else { + Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n); + if (!status.ok()) { + return errors::InvalidArgument("Could not infer 
length of input '", + arg_->name(), "': ", + status.error_message()); + } + } + return Status::OK(); +} + +Status FakeInputImpl::GetDataType(DataType* dt) const { + if (dt_specified_) { + *dt = dt_; + } else if (arg_->type() != DT_INVALID) { + *dt = arg_->type(); + } else if (!arg_->type_attr().empty()) { + Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt); + if (!status.ok()) { + return errors::InvalidArgument("Could not infer type for input '", + arg_->name(), "': ", + status.error_message()); + } + } else { + return errors::InvalidArgument("No type or type_attr field in arg '", + arg_->name(), "'"); + } + return Status::OK(); +} + +void FakeInputImpl::NSources(int n, DataType dt) const { + std::vector srcs; + srcs.reserve(n); + for (int i = 0; i < n; ++i) { + srcs.emplace_back(in_node_, i, dt); + } + builder_->Input(srcs); +} + +void FakeInputImpl::SourceList(DataTypeSlice dts) const { + std::vector srcs; + srcs.reserve(dts.size()); + for (size_t i = 0; i < dts.size(); ++i) { + srcs.emplace_back(in_node_, i, dts[i]); + } + builder_->Input(srcs); +} + +} // namespace + +// Public interface ------------------------------------------------------------ + +FakeInputFunctor FakeInput() { + return [](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(DataType dt) { + return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetDataType(dt); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(int n) { + return [n](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetN(n); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(int n, DataType dt) { + return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetN(n); + impl.SetDataType(dt); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(DataTypeSlice dts) { + // Make a copy to ensure the data will still be around when the lambda is + // called. + DataTypeVector dtv(dts.begin(), dts.end()); + return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetTypeList(dtv); + return impl.AddInputToBuilder(); + }; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/fake_input.h b/tensorflow/core/framework/fake_input.h new file mode 100644 index 0000000000..39b38e9a59 --- /dev/null +++ b/tensorflow/core/framework/fake_input.h @@ -0,0 +1,25 @@ +#ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ +#define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// These functions return values that may be passed to +// NodeDefBuilder::Input() to add an input for a test. Use them when +// you don't care about the node names/output indices providing the +// input. They also allow you to omit the input types and/or +// list length when they may be inferred. 
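+//
+// A typical use in a test looks like the following (an illustrative sketch;
+// "MyOp" and its single float input are assumed, not part of this library):
+//   NodeDef node_def;
+//   NodeDefBuilder builder("my_node", "MyOp");
+//   TF_CHECK_OK(builder.Input(FakeInput(DT_FLOAT)).Finalize(&node_def));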
+FakeInputFunctor FakeInput(); // Infer everything +FakeInputFunctor FakeInput(DataType dt); +FakeInputFunctor FakeInput(int n); // List of length n +FakeInputFunctor FakeInput(int n, DataType dt); +FakeInputFunctor FakeInput(DataTypeSlice dts); +inline FakeInputFunctor FakeInput(std::initializer_list dts) { + return FakeInput(DataTypeSlice(dts)); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc new file mode 100644 index 0000000000..b73e1ab8a9 --- /dev/null +++ b/tensorflow/core/framework/function.cc @@ -0,0 +1,878 @@ +#include "tensorflow/core/framework/function.h" + +#include + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +REGISTER_OP("_Arg") + .Output("output: T") + .Attr("T: type") + .Attr("index: int >= 0") + .Doc(R"doc( +A graph node which represents an argument to a function. + +output: The argument. +index: This argument is the index-th argument of the function. +)doc"); + +REGISTER_OP("_Retval") + .Input("input: T") + .Attr("T: type") + .Attr("index: int >= 0") + .Doc(R"doc( +A graph node which represents a return value of a function. + +input: The return value. +index: This return value is the index-th return value of the function. +)doc"); + +REGISTER_OP("_ListToArray") + .Input("input: Tin") + .Output("output: N * T") + .Attr("Tin: list(type)") + .Attr("T: type") + .Attr("N: int >= 1") + .Doc(R"doc( +Converts a list of tensors to an array of tensors. +)doc"); + +REGISTER_OP("_ArrayToList") + .Input("input: N * T") + .Output("output: out_types") + .Attr("T: type") + .Attr("N: int >= 1") + .Attr("out_types: list(type)") + .Doc(R"doc( +Converts an array of tensors to a list of tensors. +)doc"); + +namespace { + +// Extracts the actual type from "attr_values" based on its definition +// "arg_def". 
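+// For example (illustrative), an arg_def with type_attr "T" and number_attr
+// "N", instantiated with attrs {T=DT_FLOAT, N=3}, yields *num = 3 and
+// *dtype = DT_FLOAT.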
+Status ArgNumType(const InstantiateAttrValueMap& attrs, + const OpDef::ArgDef& arg_def, int* num, DataType* dtype) { + if (!arg_def.type_list_attr().empty()) { + return errors::Unimplemented("type_list is not supported."); + } + + if (arg_def.number_attr().empty()) { + *num = 1; + } else { + const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr()); + if (v == nullptr) { + return errors::NotFound("type attr not found: ", arg_def.type_attr()); + } + *num = v->i(); + } + + if (arg_def.type() != DT_INVALID) { + *dtype = arg_def.type(); + } else if (arg_def.type_attr().empty()) { + *dtype = DT_INVALID; + } else { + const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr()); + if (v == nullptr) { + return errors::NotFound("type attr not found: ", arg_def.type_attr()); + } + *dtype = v->type(); + } + return Status::OK(); +} + +string Name(int node_index) { return strings::StrCat("n", node_index); } + +string Name(int node_index, int output_index) { + if (output_index == 0) { + return Name(node_index); + } else { + return strings::StrCat("n", node_index, ":", output_index); + } +} + +string Dep(int node_index) { return strings::StrCat("^", Name(node_index)); } + +template +void AddAttr(const string& name, const T& val, NodeDef* ndef) { + SetAttrValue(val, &((*ndef->mutable_attr())[name])); +} + +Status ValidateSignatureWithAttrs(const OpDef& sig, + const InstantiateAttrValueMap& attr_values) { + // attr_values should specify all attrs defined in fdef. + for (const auto& a : sig.attr()) { + if (attr_values.find(a.name()) == attr_values.end()) { + return errors::NotFound("Attr ", a.name(), " is not found."); + } + } + + for (const auto& p : attr_values) { + if (HasPlaceHolder(p.second)) { + return errors::InvalidArgument(p.first, + " in attr_values is still a placeholder."); + } + } + + return Status::OK(); +} + +// We build a small index for all names that can be used as a node's +// input arguments. +// +// If is_func_arg is true, the name is a function's argument. In +// this case, the produced graph def has gdef.node[nid ... nid + +// num). +// +// Otherwise, the name is a function body's node return value. In +// this case, the produced graph def has one node gdef.node[nid] and +// the node's output index [idx ... idx + num) corresponds to the +// named outputs. +// +// In all cases, "dtype" specifies the data type. +struct NameInfoItem { + bool is_func_arg; + int nid; + int idx; + int num; + DataType dtype; +}; +typedef std::unordered_map NameInfoIndex; + +Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, + const InstantiateAttrValueMap& attr_values, + NameInfoIndex* name_info, + InstantiationResult* result) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR(ArgNumType(attr_values, arg_def, &num, &dtype)); + CHECK_GE(num, 1); + GraphDef* gdef = &result->gdef; + int arg_index = gdef->node_size(); + if (!name_info->insert({arg_def.name(), {true, arg_index, 0, num, dtype}}) + .second) { + return errors::InvalidArgument("Duplicated arg name."); + } + // Creates "num" nodes in the gdef. 
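+  // E.g. (illustrative) an input arg declared as "x: N*T", instantiated with
+  // N=2 and T=float, becomes two "_Arg" nodes, each with T=DT_FLOAT and
+  // consecutive "index" attrs.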
+ for (int i = 0; i < num; ++i) { + DCHECK_EQ(arg_index, gdef->node_size()); + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(arg_index)); + gnode->set_op("_Arg"); + AddAttr("T", dtype, gnode); + AddAttr("index", arg_index, gnode); + result->arg_types.push_back(dtype); + ++arg_index; + } + return Status::OK(); +} + +Status BuildNodeOutputIndex(const FunctionDef::Node& node, + const InstantiateAttrValueMap& attrs, + GetFunctionSignature get_function, + const int arg_index, NameInfoIndex* name_info) { + const OpDef* node_sig = nullptr; + TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig)); + if (node_sig->output_arg_size() == 0) { + // This node produces no output. + if (node.ret_size() != 1) { + return errors::InvalidArgument("Expect one ret name."); + } + if (!name_info->insert({node.ret(0), {false, arg_index, 0, 0, DT_INVALID}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + return Status::OK(); + } + + // When the signature says the last return value is of list(type), + // i.e., it's variadic, we need to consult + // attrs[last_retval.type_list_attr] to determine for the last arg + // * the actual number of outputs; + // * the actual data type of outputs. + const int num_retval = node_sig->output_arg_size(); + const OpDef::ArgDef& last_retval = node_sig->output_arg(num_retval - 1); + const bool last_retval_is_typelist = !last_retval.type_list_attr().empty(); + if (!last_retval_is_typelist && (node.ret_size() != num_retval)) { + return errors::InvalidArgument("Malformed function node (#ret)."); + } + int start = 0; + const int num_fixed_size_retval = + last_retval_is_typelist ? num_retval - 1 : num_retval; + for (int i = 0; i < num_fixed_size_retval; ++i) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR( + ArgNumType(attrs, node_sig->output_arg(i), &num, &dtype)); + if (!name_info->insert({node.ret(i), {false, arg_index, start, num, dtype}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + start += num; + } + if (last_retval_is_typelist) { + const AttrValue* typelist = + gtl::FindOrNull(attrs, last_retval.type_list_attr()); + if (typelist == nullptr) { + return errors::InvalidArgument("Missing attr ", + last_retval.type_list_attr(), "."); + } + if (num_fixed_size_retval + typelist->list().type_size() != + node.ret_size()) { + return errors::InvalidArgument("Wrong #ret: ", num_fixed_size_retval, " ", + typelist->list().type_size(), " ", + node.ret_size(), "."); + } + for (int i = 0; i < typelist->list().type_size(); ++i) { + if (!name_info->insert({node.ret(i), + {false, arg_index, start, 1, + typelist->list().type(i)}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + ++start; + } + } + return Status::OK(); +} + +Status InstantiateNode(const FunctionDef::Node& fnode, + const InstantiateAttrValueMap& attrs, + GetFunctionSignature get_function, + const NameInfoIndex& name_info, GraphDef* gdef) { + const OpDef* fnode_sig = nullptr; + TF_CHECK_OK(get_function(fnode.op(), &fnode_sig)); + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(gdef->node_size() - 1)); + gnode->set_op(fnode.op()); + + // Input + // + // When the signature says the last argument is of list(type), + // i.e., it's variadic, we need to consult + // attrs[last_arg.type_list_attr] to determine for the last arg + // * the number of arguments; + // * the data types of arguments. 
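+  // For example (illustrative), a call to "_ListToArray", whose signature
+  // declares "input: Tin" with "Tin: list(type)", gets its "Tin" attr filled
+  // in below from the data types of the actual tensors wired into that input.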
+ const int num_arg = fnode_sig->input_arg_size(); + bool last_arg_is_typelist = false; + if (num_arg > 0 && + !fnode_sig->input_arg(num_arg - 1).type_list_attr().empty()) { + last_arg_is_typelist = true; + } + if (!last_arg_is_typelist && (fnode.arg_size() != num_arg)) { + return errors::InvalidArgument("arg.size != sig.arg.size."); + } + const int num_fixed_size_args = last_arg_is_typelist ? num_arg - 1 : num_arg; + for (int i = 0; i < num_fixed_size_args; ++i) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR( + ArgNumType(attrs, fnode_sig->input_arg(i), &num, &dtype)); + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i)); + if (item == nullptr) { + return errors::InvalidArgument("arg[", i, "] is not found: ", + fnode.ShortDebugString()); + } + if (num != item->num || dtype != item->dtype) { + return errors::InvalidArgument("Invalid arg(", i, ") for function arg: ", + " ", num, "/", dtype, " vs. ", item->num, + "/", item->dtype, "."); + } + for (int j = 0; j < num; ++j) { + if (item->is_func_arg) { + gnode->add_input(Name(item->nid + j)); + } else { + gnode->add_input(Name(item->nid, item->idx + j)); + } + } + } + if (last_arg_is_typelist) { + AttrValue typelist; + for (int i = num_fixed_size_args; i < fnode.arg_size(); ++i) { + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i)); + if (item == nullptr) { + return errors::InvalidArgument("arg[", i, "] is not found."); + } + for (int j = 0; j < item->num; ++j) { + if (item->is_func_arg) { + gnode->add_input(Name(item->nid + j)); + } else { + gnode->add_input(Name(item->nid, item->idx + j)); + } + typelist.mutable_list()->add_type(item->dtype); + } + } + + // 'typelist' is inferred from the inputs' data types. + const auto& last_arg = fnode_sig->input_arg(num_arg - 1); + gnode->mutable_attr()->insert({last_arg.type_list_attr(), typelist}); + } + + // Control deps. + for (int i = 0; i < fnode.dep_size(); ++i) { + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i)); + if (item == nullptr) { + return errors::InvalidArgument("dep[", i, "] is not found."); + } + gnode->add_input(Dep(item->nid)); + } + + // Attrs. + for (const auto& p : attrs) { + (*gnode->mutable_attr())[p.first] = p.second; + } + + return Status::OK(); +} + +Status AddReturnNode(const OpDef::ArgDef& ret_def, + const InstantiateAttrValueMap& attrs, + const NameInfoIndex& name_info, int* ret_index, + InstantiationResult* result) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &num, &dtype)); + CHECK_GE(num, 1); + const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name()); + if (item == nullptr) { + return errors::InvalidArgument("ret is not found."); + } + if (num != item->num || dtype != item->dtype) { + return errors::InvalidArgument("Invalid ret name."); + } + GraphDef* gdef = &result->gdef; + for (int i = 0; i < num; ++i) { + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(gdef->node_size() - 1)); + gnode->set_op("_Retval"); + gnode->add_input(Name(item->nid, item->idx + i)); + AddAttr("T", dtype, gnode); + AddAttr("index", (*ret_index)++, gnode); + result->ret_types.push_back(dtype); + } + return Status::OK(); +} + +// Various helpers Print(proto) to print relevant protos to ascii. 
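+// For example (illustrative), an ArgDef declared as "x: N*T" prints as
+// "x:N*T", and one declared as "y: float" prints as "y:float".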
+string Print(const OpDef::ArgDef& arg) { + string out; + strings::StrAppend(&out, arg.name(), ":"); + if (arg.is_ref()) strings::StrAppend(&out, "Ref("); + if (!arg.number_attr().empty()) { + strings::StrAppend(&out, arg.number_attr(), "*"); + } + if (arg.type() != DT_INVALID) { + strings::StrAppend(&out, DataTypeString(arg.type())); + } else { + strings::StrAppend(&out, arg.type_attr()); + } + if (arg.is_ref()) strings::StrAppend(&out, ")"); + return out; +} + +string Print(const AttrValue& attr_value) { + if (attr_value.value_case() == AttrValue::kType) { + return DataTypeString(attr_value.type()); + } else if ((attr_value.value_case() == AttrValue::kList) && + (attr_value.list().type_size() > 0)) { + string ret = "{"; + for (int i = 0; i < attr_value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i))); + } + strings::StrAppend(&ret, "}"); + return ret; + } else if (attr_value.value_case() == AttrValue::kFunc) { + if (attr_value.func().attr_size() == 0) { + return attr_value.func().name(); + } + std::vector entries; + for (auto p : attr_value.func().attr()) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(attr_value.func().name(), "[", + str_util::Join(entries, ", "), "]"); + } + return SummarizeAttrValue(attr_value); +} + +string Print(const FunctionDef::Node& node) { + string out; + for (int i = 0; i < node.ret_size(); ++i) { + const auto& name = node.ret(i); + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, name); + } + strings::StrAppend(&out, " = ", node.op()); + if (node.attr_size() > 0) { + std::vector entries; + for (auto p : node.attr()) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]"); + } + strings::StrAppend(&out, "("); + for (int i = 0; i < node.arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, node.arg(i)); + } + strings::StrAppend(&out, ")"); + if (node.dep_size() > 0) { + strings::StrAppend(&out, " @ "); + for (int i = 0; i < node.dep_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, node.dep(i)); + } + } + return out; +} + +string Print(const FunctionDef& fdef) { + string out; + const OpDef& sig = fdef.signature(); + strings::StrAppend(&out, "\n", sig.name()); + if (sig.attr_size() > 0) { + strings::StrAppend(&out, "["); + for (int i = 0; i < sig.attr_size(); ++i) { + const auto& a = sig.attr(i); + if (i > 0) strings::StrAppend(&out, ", "); + if (a.type() == "type") { + strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values())); + } else { + strings::StrAppend(&out, a.name(), ":", a.type()); + } + } + strings::StrAppend(&out, "]"); + } + strings::StrAppend(&out, "("); + for (int i = 0; i < sig.input_arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, Print(sig.input_arg(i))); + } + strings::StrAppend(&out, ") -> ("); + for (int i = 0; i < sig.output_arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, Print(sig.output_arg(i))); + } + strings::StrAppend(&out, ") {\n"); + for (const auto& n : fdef.node()) { + strings::StrAppend(&out, " ", Print(n), "\n"); + } + strings::StrAppend(&out, "}\n"); + return out; +} + +string Print(const NodeDef& n) { + string out; + 
strings::StrAppend(&out, n.name(), " = ", n.op()); + if (n.attr_size() > 0) { + std::vector entries; + for (auto& a : n.attr()) { + entries.push_back(strings::StrCat(a.first, "=", Print(a.second))); + } + sort(entries.begin(), entries.end()); + strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]"); + } + strings::StrAppend(&out, "("); + std::vector dat; + std::vector dep; + for (StringPiece s : n.input()) { + if (s.Consume("^")) { + dep.push_back(s.ToString()); + } else { + dat.push_back(s); + } + } + strings::StrAppend(&out, str_util::Join(dat, ", "), ")"); + if (!dep.empty()) { + strings::StrAppend(&out, " @ ", str_util::Join(dep, ", ")); + } + return out; +} + +string Print(const GraphDef& gdef) { + std::vector arg; + std::vector ret; + std::vector body; + for (const NodeDef& n : gdef.node()) { + if (n.op() == "_Arg") { + arg.push_back(&n); + } else if (n.op() == "_Retval") { + ret.push_back(&n); + } else { + body.push_back(&n); + } + } + auto comp = [](const NodeDef* x, const NodeDef* y) { + int xi; + TF_CHECK_OK(GetNodeAttr(*x, "index", &xi)); + int yi; + TF_CHECK_OK(GetNodeAttr(*y, "index", &yi)); + return xi < yi; + }; + sort(arg.begin(), arg.end(), comp); + sort(ret.begin(), ret.end(), comp); + string out; + strings::StrAppend(&out, "\n("); + auto get_type = [](const NodeDef& n) { + for (auto a : n.attr()) { + if (a.first == "T") { + return DataTypeString(a.second.type()); + } + } + return DataTypeString(DT_INVALID); + }; + for (size_t i = 0; i < arg.size(); ++i) { + const NodeDef* n = arg[i]; + if (i > 0) strings::StrAppend(&out, ", "); + CHECK_EQ(2, n->attr_size()); + strings::StrAppend(&out, n->name(), ":", get_type(*n)); + } + strings::StrAppend(&out, ") -> ("); + for (size_t i = 0; i < ret.size(); ++i) { + const NodeDef* n = ret[i]; + if (i > 0) strings::StrAppend(&out, ", "); + CHECK_EQ(2, n->attr_size()); + CHECK_EQ(1, n->input_size()); + strings::StrAppend(&out, n->input(0), ":", get_type(*n)); + } + strings::StrAppend(&out, ") {\n"); + for (size_t i = 0; i < body.size(); ++i) { + strings::StrAppend(&out, " ", Print(*body[i]), "\n"); + } + strings::StrAppend(&out, "}\n"); + return out; +} + +} // end namespace + +Status InstantiateFunction(const FunctionDef& fdef, + const InstantiateAttrValueMap& attr_values, + GetFunctionSignature get_function, + InstantiationResult* result) { + const OpDef& sig = fdef.signature(); + GraphDef* gdef = &result->gdef; + gdef->Clear(); + + TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values)); + + auto substitute = [&attr_values](const string& name, AttrValue* val) { + auto iter = attr_values.find(name); + if (iter == attr_values.end()) { + return false; + } else { + *val = iter->second; + return true; + } + }; + + // Makes a copy of all attrs in fdef and substitutes placeholders. + // After this step, every attr is bound to a concrete value. 
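+  // E.g. (illustrative) a node attr {dtype: $T}, instantiated with
+  // attr_values {T: DT_FLOAT}, becomes {dtype: DT_FLOAT}.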
+ std::vector node_attrs; + node_attrs.resize(fdef.node_size()); + for (int i = 0; i < fdef.node_size(); ++i) { + for (auto attr : fdef.node(i).attr()) { + if (!SubstitutePlaceholders(substitute, &attr.second)) { + return errors::InvalidArgument("Failed to bind all placeholders in ", + SummarizeAttrValue(attr.second)); + } + CHECK(node_attrs[i].insert(attr).second); + } + } + + NameInfoIndex name_info; + Status s; + for (const OpDef::ArgDef& arg_def : sig.input_arg()) { + s = BuildInputArgIndex(arg_def, attr_values, &name_info, result); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(arg_def)); + return s; + } + } + for (int i = 0; i < fdef.node_size(); ++i) { + s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function, + gdef->node_size() + i, &name_info); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(fdef.node(i))); + return s; + } + } + + // Emits one gdef.node for each fdef.node. + for (int i = 0; i < fdef.node_size(); ++i) { + s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info, + gdef); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(fdef.node(i))); + return s; + } + } + + // Emits nodes for the function's return values. + int ret_index = 0; + for (const OpDef::ArgDef& ret_def : sig.output_arg()) { + s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(ret_def)); + return s; + } + } + + return Status::OK(); +} + +string DebugString(const FunctionDef& func_def) { return Print(func_def); } + +string DebugString(const GraphDef& instantiated_func_def) { + return Print(instantiated_func_def); +} + +string DebugStringWhole(const GraphDef& gdef) { + string ret; + for (auto fdef : gdef.library().function()) { + strings::StrAppend(&ret, Print(fdef)); + } + strings::StrAppend(&ret, "\n"); + for (auto ndef : gdef.node()) { + strings::StrAppend(&ret, Print(ndef), "\n"); + } + return ret; +} + +string Canonicalize(const string& funcname, + const InstantiateAttrValueMap& attrs) { + std::vector entries; + entries.reserve(attrs.size()); + for (auto p : attrs) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); +} + +FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, + DataTypeSlice ret_types) + : arg_types_(arg_types.begin(), arg_types.end()), + ret_types_(ret_types.begin(), ret_types.end()) { + args_.resize(arg_types_.size()); + rets_.resize(ret_types_.size()); +} + +FunctionCallFrame::~FunctionCallFrame() {} + +Status FunctionCallFrame::SetArgs(gtl::ArraySlice args) { + // Input type checks. 
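+  // E.g. (illustrative) a frame built with arg_types = {DT_FLOAT, DT_INT32}
+  // accepts exactly one float tensor followed by one int32 tensor.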
+ if (args.size() != arg_types_.size()) { + return errors::InvalidArgument("Expects ", arg_types_.size(), + " arguments, but ", args.size(), + " is provided"); + } + for (size_t i = 0; i < args.size(); ++i) { + if (arg_types_[i] != args[i].dtype()) { + return errors::InvalidArgument( + "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ", + DataTypeString(args[i].dtype()), " is provided"); + } + args_[i] = args[i]; + } + return Status::OK(); +} + +Status FunctionCallFrame::GetRetvals(std::vector* rets) const { + rets->clear(); + rets->reserve(rets_.size()); + for (size_t i = 0; i < rets_.size(); ++i) { + auto item = rets_[i]; + if (item.has_val) { + rets->push_back(item.val); + } else { + return errors::Internal("Retval[", i, "] does not have value"); + } + } + return Status::OK(); +} + +Status FunctionCallFrame::GetArg(int index, Tensor* val) const { + if (index < 0 || static_cast(index) >= args_.size()) { + return errors::OutOfRange("GetArg ", index, " is not within [0, ", + args_.size(), ")"); + } + *val = args_[index]; + return Status::OK(); +} + +Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { + if (index < 0 || static_cast(index) >= rets_.size()) { + return errors::OutOfRange("SetRetval ", index, " is not within [0, ", + rets_.size(), ")"); + } + if (val.dtype() != ret_types_[index]) { + return errors::InvalidArgument( + "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]), + ", but ", DataTypeString(val.dtype()), " is provided."); + } + Retval* item = &rets_[index]; + if (!item->has_val) { + item->has_val = true; + item->val = val; + } else { + return errors::Internal("Retval[", index, "] has already been set."); + } + return Status::OK(); +} + +FunctionLibraryDefinition::FunctionLibraryDefinition( + const FunctionDefLibrary& def_lib) + : function_defs_(def_lib.function_size()) { + for (auto fdef : def_lib.function()) { + // The latter function definition wins. 
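+    // I.e., if def_lib contains two FunctionDefs with the same
+    // signature().name(), the assignment below replaces the earlier entry,
+    // so the definition that appears later in the library is kept.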
+ function_defs_[fdef.signature().name()] = fdef; + } +} + +FunctionLibraryDefinition::~FunctionLibraryDefinition() {} + +const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { + auto iter = function_defs_.find(name); + if (iter == function_defs_.end()) { + return nullptr; + } else { + return &iter->second; + } +} + +const OpDef* FunctionLibraryDefinition::LookUp(const string& op, + Status* status) const { + auto fdef = Find(op); + if (fdef != nullptr) { + return &(fdef->signature()); + } + return OpRegistry::Global()->LookUp(op, status); +} + +Status InstantiateFunction(const FunctionDef& fdef, + InstantiateAttrValueSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result) { + InstantiateAttrValueMap m; + for (const auto& aval : attr_values) { + m.insert({aval.first, aval.second.proto}); + } + return InstantiateFunction(fdef, m, get_function, result); +} + +string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) { + InstantiateAttrValueMap m; + for (const auto& aval : attrs) { + m.insert({aval.first, aval.second.proto}); + } + return Canonicalize(funcname, m); +} + +Status FunctionLibraryRuntime::Instantiate(const string& function_name, + InstantiateAttrValueSlice attrs, + Handle* handle) { + InstantiateAttrValueMap m; + for (const auto& aval : attrs) { + m.insert({aval.first, aval.second.proto}); + } + return Instantiate(function_name, m, handle); +} + +void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { + if (val.size() >= 2 && val[0] == '$') { + proto.set_placeholder(val.data() + 1, val.size() - 1); + } else { + SetAttrValue(val, &proto); + } +} + +FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( + const string& name, + gtl::ArraySlice> attrs) { + AttrValueWrapper ret; + ret.proto.mutable_func()->set_name(name); + for (const auto& a : attrs) { + ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto}); + } + return ret; +} + +FunctionDef::Node FunctionDefHelper::Node::ToProto() const { + FunctionDef::Node n; + for (const string& r : this->ret) { + n.add_ret(r); + } + n.set_op(this->op); + for (const string& a : arg) { + n.add_arg(a); + } + for (const auto& a : this->attr) { + n.mutable_attr()->insert({a.first, a.second.proto}); + } + for (const string& d : dep) { + n.add_dep(d); + } + return n; +} + +/* static */ +FunctionDef FunctionDefHelper::Define(const string& name, + gtl::ArraySlice arg_def, + gtl::ArraySlice ret_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def) { + FunctionDef fdef; + OpDefBuilder b(name); + for (const auto& a : arg_def) b.Input(a); + for (const auto& r : ret_def) b.Output(r); + for (const auto& a : attr_def) b.Attr(a); + TF_CHECK_OK(b.Finalize(fdef.mutable_signature())); + for (const auto& n : node_def) { + *(fdef.add_node()) = n.ToProto(); + } + return fdef; +} + +FunctionDef FunctionDefHelper::Define(gtl::ArraySlice arg_def, + gtl::ArraySlice ret_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def) { + return Define("_", arg_def, ret_def, attr_def, node_def); +} + +namespace gradient { + +typedef std::unordered_map OpGradFactory; + +OpGradFactory* GetOpGradFactory() { + static OpGradFactory* factory = new OpGradFactory; + return factory; +} + +bool RegisterOp(const string& op, Creator func) { + CHECK(GetOpGradFactory()->insert({op, func}).second) + << "Duplicated gradient for " << op; + return true; +} + +Status GetOpGradientCreator(const string& op, Creator* creator) { + auto fac = GetOpGradFactory(); + 
auto iter = fac->find(op); + if (iter == fac->end()) { + return errors::NotFound("No gradient defined for op: ", op); + } + *creator = iter->second; + return Status::OK(); +} + +} // end namespace gradient + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h new file mode 100644 index 0000000000..1ef93a0533 --- /dev/null +++ b/tensorflow/core/framework/function.h @@ -0,0 +1,376 @@ +#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_ +#define TENSORFLOW_FRAMEWORK_FUNCTION_H_ + +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +class CancellationManager; +class Node; +class OpKernel; + +// FunctionDefHelper::Define is a convenient helper to construct a +// FunctionDef proto. +// +// E.g., +// FunctionDef my_func = FunctionDefHelper::Define( +// "my_func_name", +// {"x:T", "y:T" /* one string per argument */}, +// {"z:T" /* one string per return value */}, +// {"T: {float, double}" /* one string per attribute */}, +// { +// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}} +// /* one entry per function node */ +// }) +// +// NOTE: When we have a TFLang parser, we can add another helper: +// FunctionDef FunctionDefHelper::Define(const string& tf_func); +class FunctionDefHelper { + public: + // AttrValueWrapper has copy constructors for the type T so that + // it's easy to construct a simple AttrValue proto. + // + // If T is a string type (const char*, string, or StringPiece), and + // it starts with "$", we construct a AttrValue of "placeholder". + // + // E.g., + // std:: x = {"T", "$T"} + // is a named attr value placeholder. + struct AttrValueWrapper { + AttrValue proto; + + AttrValueWrapper() {} + + template + AttrValueWrapper(T val) { // NOLINT(runtime/explicit) + SetAttrValue(val, &proto); + } + + private: + void InitFromString(StringPiece val); + }; + + // Constructs an AttrValue.func given the "name" and "attrs". + static AttrValueWrapper FunctionRef( + const string& name, + gtl::ArraySlice> attrs); + static AttrValueWrapper FunctionRef(const string& name) { + return FunctionRef(name, {}); + } + + // Node is used to consturct FunctionDef.Node using initialization + // lists. E.g., + // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y + struct Node { + std::vector ret; + string op; + std::vector arg; + std::vector> attr; + std::vector dep; + + FunctionDef::Node ToProto() const; + }; + + static FunctionDef Define(const string& function_name, + gtl::ArraySlice arg_def, + gtl::ArraySlice ret_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def); + + // Defines an anonymous function. I.e., its name is not relevant. + static FunctionDef Define(gtl::ArraySlice arg_def, + gtl::ArraySlice ret_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def); + + // Helpers to construct a constant scalar. 
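+  //
+  // E.g. (illustrative; equivalent to the hand-built "two" node in
+  // test::function::XTimesTwo):
+  //   Node two = FunctionDefHelper::Const<int64>("two", 2);
+  //   // -> {{"two"}, "Const", {}, {{"dtype", DT_INT64}, {"value", <scalar 2>}}}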
+ template + static Node Const(const string& name, const T& val) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum::value; + n.attr.push_back({"dtype", dtype}); + Tensor t(dtype, TensorShape({})); + t.scalar()() = val; + n.attr.push_back({"value", t}); + return n; + } + + template + static Node Const(const string& name, gtl::ArraySlice vals) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum::value; + n.attr.push_back({"dtype", dtype}); + int64 num = vals.size(); + Tensor t(dtype, TensorShape({num})); + for (int i = 0; i < vals.size(); ++i) { + t.flat()(i) = vals[i]; + } + n.attr.push_back({"value", t}); + return n; + } +}; + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( + const string& val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { + InitFromString(val); +} + +// Instantiate a function. +// +// "fdef" encodes a TF function with some attrs in fdef.signature.attr +// containing placeholders. InstantiateFunction binds these +// placeholders and produces an instantiated function encoded in +// "result.gdef". The value to substitute a placeholder is given by +// "attr_values", which is a map from a placeholder name to an attr +// value. +// +// InstatiateFunction calls "get_function" to find signatures of other +// functions and primitive ops. + +// Placeholders in "fdef" is substitued based on "attr_values" here. +typedef ::tensorflow::protobuf::Map InstantiateAttrValueMap; +typedef gtl::ArraySlice> + InstantiateAttrValueSlice; + +// GetFunctionSignature(func name, opdef) returns OK if the func name is found +// and opdef is filled with a pointer to the corresponding signature +// (a OpDef proto). Otherwise, returns an error. +typedef std::function + GetFunctionSignature; + +struct InstantiationResult { + DataTypeVector arg_types; + DataTypeVector ret_types; + GraphDef gdef; +}; +Status InstantiateFunction(const FunctionDef& fdef, + const InstantiateAttrValueMap& attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); +Status InstantiateFunction(const FunctionDef& fdef, + InstantiateAttrValueSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); + +// Returns a debug string for a function definition. +// +// The returned text is multiple-line. It is intended to be +// human-readable rather than being friendly to parsers. It is _NOT_ +// intended to be the canonical string representation of "func_def". +// Particularly, it may not include all information presented in +// "func_def" (e.g., comments, description of the function arguments, +// etc.) +string DebugString(const FunctionDef& func_def); +string DebugString(const GraphDef& instantiated_func_def); + +// Returns a debug string for a top level graph (the main program and +// its supporting functions defined in its library). +string DebugStringWhole(const GraphDef& gdef); + +// Returns a canonicalized string for the instantiation of the +// function of the given "name" and attributes "attrs". +// +// The returned string is guaranteed to be stable within one address +// space. But it may be change as the implementation +// evolves. Therefore, it should not be persisted or compared across +// address spaces. 
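+//
+// For example (this is what the Canonicalize test in function_test.cc
+// checks):
+//   Canonicalize("MatMul", {{"T", DT_FLOAT},
+//                           {"transpose_a", false},
+//                           {"transpose_b", false}})
+// returns "MatMul[T=float,transpose_a=false,transpose_b=false]", and the
+// result does not depend on the order in which the attrs are listed.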
+string Canonicalize(const string& funcname, + const InstantiateAttrValueMap& attrs); +string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs); + +// Represents a function call frame. I.e., the data structure used to +// pass arguments to a function and retrieve its results. +// +// Runtime must arrange accesses to one FunctionCallFrame s.t. +// 1. SetArgs() happens before any GetArg(); +// 2. GetRetvals happens after all SetRetval(); +class FunctionCallFrame { + public: + FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types); + ~FunctionCallFrame(); + + // Caller methods. + Status SetArgs(gtl::ArraySlice args); + Status GetRetvals(std::vector* rets) const; + + // Callee methods. + Status GetArg(int index, Tensor* val) const; + Status SetRetval(int index, const Tensor& val); + + private: + DataTypeVector arg_types_; + DataTypeVector ret_types_; + gtl::InlinedVector args_; + struct Retval { + bool has_val = false; + Tensor val; + }; + gtl::InlinedVector rets_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame); +}; + +// Helper to maintain a map between function names in a given +// FunctionDefLibrary and function definitions. +class FunctionLibraryDefinition : public OpRegistryInterface { + public: + explicit FunctionLibraryDefinition(const FunctionDefLibrary& lib_def); + ~FunctionLibraryDefinition() override; + + // Returns nullptr if "func" is not defined in "lib_def". Otherwise, + // returns its definition proto. + const FunctionDef* Find(const string& func) const; + + // OpRegistryInterface method. Useful for constructing a Graph. + // + // If "op" is defined in the library, returns its signature. + // Otherwise, assume "op" is a primitive op and returns its op + // signature. + const OpDef* LookUp(const string& op, Status* status) const override; + + private: + std::unordered_map function_defs_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryDefinition); +}; + +// Forward declare. Defined in common_runtime/function.h +struct FunctionBody; + +class FunctionLibraryRuntime { + public: + virtual ~FunctionLibraryRuntime() {} + + // Instantiate a function with the given "attrs". + // + // Returns OK and fills in "handle" if the instantiation succeeds. + // Otherwise returns an error and "handle" is undefined. + typedef uint64 Handle; + virtual Status Instantiate(const string& function_name, + const InstantiateAttrValueMap& attrs, + Handle* handle) = 0; + Status Instantiate(const string& function_name, + InstantiateAttrValueSlice attrs, Handle* handle); + + // Returns the function body for the instantiated function given its + // handle 'h'. Returns nullptr if "h" is not found. + // + // *this keeps the ownership of the returned object, which remains alive + // as long as *this. + virtual const FunctionBody* GetFunctionBody(Handle h) = 0; + + // Asynchronously invokes the instantiated function identified by + // "handle". + // + // If function execution succeeds, "done" is called with OK and + // "*rets" is filled with the function's return values. Otheriwse, + // "done" is called with an error status. + // + // Does not take ownership of "rets". + struct Options { + CancellationManager* cancellation_manager = nullptr; + }; + typedef std::function DoneCallback; + virtual void Run(const Options& opts, Handle handle, + gtl::ArraySlice args, std::vector* rets, + DoneCallback done) = 0; + + // Creates a "kernel" for the given node def "ndef". + // + // If succeeds, returns OK and the caller takes the ownership of the + // returned "*kernel". 
Otherwise, returns an error.
+  virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
+
+  // Returns true iff 'function_name' is the name of a defined function.
+  virtual bool IsDefined(const string& function_name) = 0;
+};
+
+// To register a gradient function for a builtin op, one should use
+//   REGISTER_OP_GRADIENT(<op_name>, <C++ grad factory>);
+//
+// Typically, the C++ grad factory is a plain function that can be
+// converted into ::tensorflow::gradient::Creator, which is
+//   std::function<Status(const AttrSlice&, FunctionDef*)>.
+//
+// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
+// definition of a brain function which computes the gradient for the
+// <op> when the <op> is instantiated with the given attrs.
+//
+// E.g.,
+//
+// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
+//   bool transpose_a;
+//   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
+//   bool transpose_b;
+//   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
+//   DataType dtype;
+//   TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
+//   if (!transpose_a && !transpose_b) {
+//     *g = FunctionDefHelper::Define(
+//         "MatMulGrad",
+//         {"x:T", "y:T", "dz:T"},    // Inputs to this function
+//         {"dx:T", "dy:T"},          // Outputs from this function
+//         {"T: {float, double}"},    // Attributes needed by this function
+//         {
+//             {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
+//             {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
+//             {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
+//             {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
+//         });
+//   } else {
+//     ... ...
+//   }
+//   return Status::OK();
+// }
+//
+// NOTE: $T is substituted with the type variable "T" when the
+// gradient function MatMulGrad is instantiated.
+//
+// TODO(zhifengc): Better documentation somewhere.
+
+// Macros to define a gradient function factory for a primitive
+// operation.
+#define REGISTER_OP_GRADIENT(name, fn) \
+  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)
+
+#define REGISTER_OP_NO_GRADIENT(name) \
+  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)
+
+#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
+  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
+
+#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \
+  static bool unused_grad_##ctr = ::tensorflow::gradient::RegisterOp(name, fn)
+
+namespace gradient {
+// Register a gradient creator for the "op".
+typedef std::function<Status(const AttrSlice&, FunctionDef*)> Creator;
+bool RegisterOp(const string& op, Creator func);
+
+// Returns OK if the gradient creator for the "op" is found (it may be
+// nullptr if REGISTER_OP_NO_GRADIENT is used).
+Status GetOpGradientCreator(const string& op, Creator* creator);
+}  // end namespace gradient
+
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_FUNCTION_H_
diff --git a/tensorflow/core/framework/function.proto b/tensorflow/core/framework/function.proto
new file mode 100644
index 0000000000..4b8a26947c
--- /dev/null
+++ b/tensorflow/core/framework/function.proto
@@ -0,0 +1,68 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/op_def.proto";
+
+// A library is a set of named functions.
+message FunctionDefLibrary {
+  repeated FunctionDef function = 1;
+}
+
+// A function can be instantiated when the runtime can bind every attr
+// with a value. When a GraphDef has a call to a function, it must
+// have binding for every attr defined in the signature.
+//
+// TODO(zhifengc):
+// * device spec, etc.
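+//
+// For example, the SquarePlusOne function defined in
+// framework/function_test.cc via FunctionDefHelper::Define carries the
+// signature
+//   SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T)
+// and three nodes
+//   a = Square[T=$T](x)
+//   o = One[T=$T]()
+//   y = Add[T=$T](a, o)
+// where "$T" is an attr placeholder that is bound to a concrete type when
+// the function is instantiated.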
+message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // The body of the function. + repeated Node node = 2; // function.node.ret[*] are unique. + + // A node is a multi-value assignment: + // (ret[0], ret[1], ...) = func(arg[0], arg[1], ...) + // + // By convention, "func" is resolved by consulting with a user-defined + // library first. If not resolved, "func" is assumed to be a builtin op. + message Node { + // This node produces multiple outputs. They are named ret[0], + // ret[1], ..., etc. + // + // REQUIRES: function.node.ret[*] are unique across all nodes. + // REQUIRES: ret.size == func/op def's number of output args. + repeated string ret = 1; + + // The op/function name. + string op = 2; + + // Arguments passed to this func/op. + // + // arg[i] must be either one of + // function.signature.input_args[*].name or one of + // function.node[*].ret[*]. + // + // REQUIRES: arg.size == func/op def's number of input args. + repeated string arg = 3; + + // Control dependencies. + // + // dep[i] must be one of function.node[*].ret[*] or one of + // function.signature.input_args[*].name. + repeated string dep = 4; + + // Attrs. + // + // 'attr' maps names defined by 'func's attr defs to attr values. + // attr values may have placeholders which are substituted + // recursively by concrete values when this node is instantiated. + // These placeholdes must name an attr listed in the FunctionDef's + // signature. + map attr = 5; + } +} diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc new file mode 100644 index 0000000000..c9483fad18 --- /dev/null +++ b/tensorflow/core/framework/function_test.cc @@ -0,0 +1,634 @@ +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/port.h" +#include + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +Status GetOpSig(const string& op, const OpDef** sig) { + Status s; + *sig = OpRegistry::Global()->LookUp(op, &s); + return s; +} + +REGISTER_OP("One") + .Output("y: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Returns a tensor with a single element (1) of type T. + +y: A scalar in type T. + +)doc"); + +static InstantiateAttrValueMap kNoAttrs; + +TEST(TFunc, SquarePlusOne) { + RequireDefaultOps(); + auto fdef = FDH::Define( + // Name + "SquarePlusOne", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attrs + {"T: {float, double, int32, int64}"}, + // Nodes + {// a = Square(x) + {{"a"}, "Square", {"x"}, {{"T", "$T"}}}, + // o = One() + // NOTE: We can also have a Cast(x) instead. 
+ {{"o"}, "One", {}, {{"T", "$T"}}}, + // y = Add(a, o) + {{"y"}, "Add", {"a", "o"}, {{"T", "$T"}}}}); + + const char* e = R"P( +SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { + a = Square[T=$T](x) + o = One[T=$T]() + y = Add[T=$T](a, o) +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + // Instantiate one with T=float + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result)); + const char* e2 = R"P( +(n0:float) -> (n3:float) { + n1 = Square[T=float](n0) + n2 = One[T=float]() + n3 = Add[T=float](n1, n2) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +// NOTE: This is the simplest Map op. It takes a f:T->U. +REGISTER_OP("Map") + .Input("x: N * T") + .Output("y: N * U") + .Attr("T: type") + .Attr("U: type") + .Attr("N: int >= 1") + // .Attr("func: func_name_with_attr") + .Doc(R"doc( +Applies the 'func' on every input. I.e., + +y[i] = func<...>(x[i]) + +x: N tensors, each of type T; +y: N tensors, each of type U; + +)doc"); + +TEST(TFunc, AddSquared) { + auto fdef = FDH::Define( + // Name + "AddSquared", + // Args + {"x: N*T"}, + // Return values + {"y: T"}, + // Attrs + {"N:int", "T:{float, double, int32, int64}"}, + // Nodes + {// a = Map,T=$T,U=$T,N=$N>(x) + {{"a"}, + "Map", + {"x"}, + {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})}, + {"T", "$T"}, + {"U", "$T"}, + {"N", "$N"}}}, + // y = AddN(a) + {{"y"}, "AddN", {"a"}, {{"N", "$N"}, {"T", "$T"}}}}); + + const char* e = R"P( +AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { + a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x) + y = AddN[N=$N, T=$T](a) +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + // Instantiate one with T=float + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig, + &result)); + const char* e2 = R"P( +(n0:float, n1:float, n2:float) -> (n4:float) { + n3 = Map[N=3, T=float, U=float, func=Square[T=float]](n0, n1, n2) + n4 = AddN[N=3, T=float](n3, n3:1, n3:2) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +TEST(TFunc, ControlDeps) { + auto fdef = FDH::Define( + // Name + "ControlDeps", + // Args + {"x: float"}, + // Return values + {}, + // Attrs + {}, + // Nodes + { + {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}}, + {{"u"}, "NoOp", {}, {}, {"a"}}, + {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}}, + {{"v"}, "NoOp", {}, {}, {"b"}}, + {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}}, + }); + const char* e = R"P( +ControlDeps(x:float) -> () { + a = One[T=float]() @ x + u = NoOp() @ a + b = One[T=float]() @ u + v = NoOp() @ b + c = One[T=float]() @ a, v +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + const char* e2 = R"P( +(n0:float) -> () { + n1 = One[T=float]() @ n0 + n2 = NoOp() @ n1 + n3 = One[T=float]() @ n2 + n4 = NoOp() @ n3 + n5 = One[T=float]() @ n1, n4 +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +TEST(TFunc, XTimesTwo) { + auto expect = R"P( +XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { + two = Const[dtype=int64, value=Tensor]() + scale = Cast[DstT=$T, SrcT=int64](two) + y = 
Mul[T=$T](x, scale) +} +)P"; + EXPECT_EQ(expect, DebugString(test::function::XTimesTwo())); +} + +TEST(TFunc, WXPlusB) { + auto expect = R"P( +WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) { + mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x) + y = Add[T=$T](mm, b) +} +)P"; + EXPECT_EQ(expect, DebugString(test::function::WXPlusB())); +} + +TEST(TFunc, Body_TypeList) { + const Tensor kZero = test::AsScalar(0); + auto fdef = FDH::Define( + // Name + "Test", + // Args + {"i:float"}, + // Return values + {"o:float"}, + // Attrs + {}, + // Nodes + {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}}, + {{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}}, + {{"a", "b", "c", "d"}, + "_ArrayToList", + {"s"}, + {{"N", 4}, + {"T", DT_FLOAT}, + {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}}, + {{"l"}, "Mul", {"a", "b"}, {{"T", DT_FLOAT}}}, + {{"r"}, "Mul", {"c", "d"}, {{"T", DT_FLOAT}}}, + {{"x"}, "_ListToArray", {"l", "r"}, {{"N", 2}, {"T", DT_FLOAT}}}, + {{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}}); + + const char* e = R"P( +Test(i:float) -> (o:float) { + zero = Const[dtype=int32, value=Tensor]() + s = Split[T=float, num_split=4](zero, i) + a, b, c, d = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s) + l = Mul[T=float](a, b) + r = Mul[T=float](c, d) + x = _ListToArray[N=2, T=float](l, r) + o = AddN[N=2, T=float](x) +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + const char* e2 = R"P( +(n0:float) -> (n7:float) { + n1 = Const[dtype=int32, value=Tensor]() + n2 = Split[T=float, num_split=4](n1, n0) + n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3) + n4 = Mul[T=float](n3, n3:1) + n5 = Mul[T=float](n3:2, n3:3) + n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5) + n7 = AddN[N=2, T=float](n6, n6:1) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +REGISTER_OP("Cond") + .Input("input: Tin") + .Output("output: out_types") + .Attr("Tin: list(type)") + .Attr("out_types: list(type)") + .Attr("cond: func") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Doc(R"doc( +output = Cond(input) ? then_branch(input) : else_branch(input) + +cond: A function takes 'input' and returns a scalar. +then_branch: A funcion takes 'input' and returns 'output'. +else_branch: A funcion takes 'input' and returns 'output'. 
+)doc"); + +TEST(TFunc, Body_Array_List_Converter) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x:float"}, + // Return values + {"z:float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x"}, + {{"Tin", DataTypeSlice{DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond")}, + {"then_branch", FDH::FunctionRef("MyThen")}, + {"else_branch", FDH::FunctionRef("MyElse")}}}, + {{"z"}, + "Cond", + {"y", "y"}, + {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + + const char* e = R"P( +MySelect(x:float) -> (z:float) { + y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) + z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y) +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + const char* e2 = R"P( +(n0:float) -> (n2:float) { + n1 = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](n0) + n2 = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](n1, n1) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +static void HasError(const Status& s, const string& substr) { + EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + << s << ", expected substring " << substr; +} + +TEST(InstantiateErrors, Not_Sufficient_Attrs) { + auto fdef = + FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result), + "T is not found"); +} + +TEST(InstantiateErrors, AttrValue_Value_Placeholder) { + auto fdef = + FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result), + "T in attr_values is still a placeholder"); +} + +TEST(InstantiateErrors, Unbounded_Attr) { + auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"}, + { + {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result), + "Failed to bind all placeholders"); +} + +TEST(InstantiateErrors, DupArgs) { + auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Duplicated arg name"); +} + +TEST(InstantiateErrors, Dup_Arg_Node_Name) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Duplicated ret name"); +} + +TEST(InstantiateErrors, Dup_Node_Names) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, + {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Duplicated ret name"); +} + +TEST(InstantiateErrors, Node_Signature_Mismatch_NoOp) { + auto fdef = FDH::Define("test", 
{"x:float"}, {}, {}, + { + {{"y", "z"}, "NoOp", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Expect one ret name"); +} + +TEST(InstantiateErrors, Node_Signature_Mismatch) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y", "z"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Malformed function node (#ret)"); +} + +TEST(InstantiateErrors, Node_Arg_Notfound) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "arg[1] is not found"); +} + +TEST(InstantiateErrors, Node_Arg_Mismatch) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Invalid arg(0) for function arg"); +} + +TEST(InstantiateErrors, Node_Arg_ControlMissing) { + auto fdef = + FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "dep[0] is not found"); +} + +TEST(InstantiateErrors, FuncRet_Missing) { + auto fdef = FDH::Define("test", {}, {"y: float"}, {}, + { + {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "ret is not found"); +} + +TEST(InstantiateErrors, FuncRet_Mismatch) { + auto fdef = FDH::Define("test", {}, {"y: float"}, {}, + { + {{"y"}, "One", {}, {{"T", DT_DOUBLE}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Invalid ret name.\n\t In y"); +} + +TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "x"}, + {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Missing attr out_types"); +} + +TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "x"}, + {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Wrong #ret: 0 2 1"); +} + +TEST(InstantiateErrors, TypeList_Missing_Arg) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "unknown"}, + {{"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + 
InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "arg[1] is not found"); +} + +TEST(FunctionCallFrame, Void_Void) { + FunctionCallFrame frame({}, {}); + EXPECT_OK(frame.SetArgs({})); + auto a = test::AsTensor({100}); + HasError(frame.SetArgs({a}), "Invalid argument"); + Tensor v; + HasError(frame.GetArg(0, &v), "Out of range"); + HasError(frame.SetRetval(0, v), "Out of range"); + std::vector rets; + EXPECT_OK(frame.GetRetvals(&rets)); + EXPECT_EQ(rets.size(), 0); +} + +TEST(FunctionCallFrame, Float_Float_Float) { + FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}); + HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments"); + auto a = test::AsTensor({100}); + auto b = test::AsTensor({200}); + auto c = test::AsTensor({300}); + HasError(frame.SetArgs({a, c}), + "Invalid argument: Expects arg[1] to be float"); + EXPECT_OK(frame.SetArgs({a, b})); + + Tensor v; + HasError(frame.GetArg(-1, &v), "Out of range"); + HasError(frame.GetArg(2, &v), "Out of range"); + EXPECT_OK(frame.GetArg(0, &v)); + test::ExpectTensorEqual(a, v); + EXPECT_OK(frame.GetArg(1, &v)); + test::ExpectTensorEqual(b, v); + + v = test::AsTensor({-100}); + HasError(frame.SetRetval(-1, v), "Out of range"); + HasError(frame.SetRetval(1, v), "Out of range"); + HasError(frame.SetRetval(0, test::AsTensor({-100})), + "Invalid argument: Expects ret[0] to be float"); + + std::vector rets; + HasError(frame.GetRetvals(&rets), "does not have value"); + EXPECT_OK(frame.SetRetval(0, v)); + HasError(frame.SetRetval(0, v), "has already been set"); + + EXPECT_OK(frame.GetRetvals(&rets)); + EXPECT_EQ(rets.size(), 1); + test::ExpectTensorEqual(rets[0], v); +} + +TEST(Canonicalize, Basic) { + EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT}, + {"transpose_a", false}, + {"transpose_b", false}}), + "MatMul[T=float,transpose_a=false,transpose_b=false]"); + EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT}, + {"transpose_b", false}, + {"transpose_a", false}}), + "MatMul[T=float,transpose_a=false,transpose_b=false]"); + EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE}, + {"transpose_b", true}, + {"transpose_a", false}}), + "MatMul[T=double,transpose_a=false,transpose_b=true]"); +} + +TEST(FunctionLibraryDefinitionTest, Find) { + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib_def(proto); + + EXPECT_EQ(lib_def.Find("XTimes16"), nullptr); + + auto expect = R"P( +XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { + two = Const[dtype=int64, value=Tensor]() + scale = Cast[DstT=$T, SrcT=int64](two) + y = Mul[T=$T](x, scale) +} +)P"; + auto found = lib_def.Find("XTimesTwo"); + ASSERT_NE(found, nullptr); + EXPECT_EQ(expect, DebugString(*found)); +} + +TEST(FunctionLibraryDefinitionTest, LookUp) { + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib_def(proto); + + Status s; + EXPECT_EQ(lib_def.LookUp("XTimes16", &s), nullptr); + + auto found = lib_def.LookUp("XTimesTwo", &s); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->DebugString(), + test::function::XTimesTwo().signature().DebugString()); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc new file mode 100644 index 0000000000..5ead947076 --- /dev/null +++ b/tensorflow/core/framework/function_testlib.cc @@ -0,0 +1,146 @@ +#include "tensorflow/core/framework/function_testlib.h" + +#include 
"tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor_testutil.h" + +namespace tensorflow { +namespace test { +namespace function { + +typedef FunctionDefHelper FDH; + +GraphDef GDef(gtl::ArraySlice nodes, + gtl::ArraySlice funcs) { + GraphDef g; + for (auto n : nodes) { + *(g.add_node()) = n; + } + auto lib = g.mutable_library(); + for (auto f : funcs) { + *(lib->add_function()) = f; + } + return g; +} + +// Helper to construct a NodeDef. +NodeDef NDef(const string& name, const string& op, + gtl::ArraySlice inputs, + gtl::ArraySlice> attrs, + const string& device) { + NodeDef n; + n.set_name(name); + n.set_op(op); + for (auto in : inputs) n.add_input(in); + n.set_device(device); + for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto}); + return n; +} + +FunctionDef NonZero() { + return FDH::Define( + // Name + "NonZero", + // Args + {"x:T"}, + // Return values + {"y:T"}, + // Attr def + {"T:{float, double, int32, int64, string}"}, + // Nodes + { + {{"y"}, "Identity", {"x"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimesTwo() { + const Tensor kTwo = test::AsScalar(2); + return FDH::Define( + // Name + "XTimesTwo", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimesFour() { + return FDH::Define( + // Name + "XTimesFour", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}}, + {{"y"}, "XTimesTwo", {"x2"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimes16() { + return FDH::Define( + // Name + "XTimes16", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}}, + {{"y"}, "XTimesFour", {"x4"}, {{"T", "$T"}}}, + }); +} + +FunctionDef WXPlusB() { + return FDH::Define( + // Name + "WXPlusB", + // Args + {"w: T", "x: T", "b: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double}"}, + // Nodes + {{{"mm"}, + "MatMul", + {"w", "x"}, + {{"T", "$T"}, + {"transpose_a", false}, + {"transpose_b", false}, + {"_kernel", "eigen"}}}, + {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}}); +} + +FunctionDef Swap() { + return FDH::Define( + // Name + "Swap", + // Args + {"i0: T", "i1: T"}, + // Return values + {"o0: T", "o1: T"}, + // Attr def + {"T: {float, double}"}, + // Nodes + {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}}, + {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}); +} + +} // end namespace function +} // end namespace test +} // end namespace tensorflow diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h new file mode 100644 index 0000000000..ed0446ea85 --- /dev/null +++ b/tensorflow/core/framework/function_testlib.h @@ -0,0 +1,53 @@ +#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_ +#define TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_ + +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace test { 
+namespace function { + +// Helper to construct a NodeDef. +NodeDef NDef( + const string& name, const string& op, gtl::ArraySlice inputs, + gtl::ArraySlice> + attrs = {}, + const string& device = ""); + +// Helper to construct a GraphDef proto. +GraphDef GDef(gtl::ArraySlice nodes, + gtl::ArraySlice funcs = {}); + +// For testing convenience, we provide a few simple functions that can +// be easily executed and tested. + +// x:T -> x * 2. +FunctionDef XTimesTwo(); + +// x:T -> (x * 2) * 2. +FunctionDef XTimesFour(); + +// x:T -> ((x * 2) * 2) * 2. +FunctionDef XTimes16(); + +// w:T, x:T, b:T -> MatMul(w, x) + b +FunctionDef WXPlusB(); + +// x:T -> x:T, T is a type which we automatically converts to a bool. +FunctionDef NonZero(); + +// x:T, y:T -> y:T, x:T +FunctionDef Swap(); + +} // end namespace function +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_ diff --git a/tensorflow/core/framework/graph.proto b/tensorflow/core/framework/graph.proto new file mode 100644 index 0000000000..a9bc07e88c --- /dev/null +++ b/tensorflow/core/framework/graph.proto @@ -0,0 +1,103 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/attr_value.proto"; +import "tensorflow/core/framework/function.proto"; + +// Represents the graph of operations +// TODO(sanjay): Also want to put the following somewhere: +// * random_seed +// * replicas: Do we stamp them out in python itself? +// * where to load parameters +// * optimizer info? does it go with the parameter layers/ops? +message GraphDef { + repeated NodeDef node = 1; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", {...}} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". 
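+  //
+  // For example, input: ["a", "b:1", "^c"] reads output 0 of node "a" and
+  // output 1 of node "b", and adds a control dependency on node "c".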
+ repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= COLOCATED_NODE | PARTIAL_SPEC + // + // COLOCATED_NODE ::= "@" NODE_NAME // See NodeDef.name above. + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "@other/node" (colocate with "other/node") + // * "/job:worker/replica:0/task:1/gpu:3" (full specification) + // * "/job:worker/gpu:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // TODO(josh11b): Add some examples here showing best practices. + map attr = 5; +}; diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc new file mode 100644 index 0000000000..1e0d280126 --- /dev/null +++ b/tensorflow/core/framework/graph_def_util.cc @@ -0,0 +1,25 @@ +#include "tensorflow/core/framework/graph_def_util.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +string SummarizeGraphDef(const GraphDef& graph_def) { + string ret; + for (const NodeDef& node : graph_def.node()) { + strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n"); + } + return ret; +} + +Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { + for (const NodeDef& node : graph_def.node()) { + TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h new file mode 100644 index 0000000000..7a2ec9c7a7 --- /dev/null +++ b/tensorflow/core/framework/graph_def_util.h @@ -0,0 +1,29 @@ +#ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Produce a human-readable version of a GraphDef that is more concise +// than a text-format proto. +string SummarizeGraphDef(const GraphDef& graph_def); + +// Validates the syntax of a GraphDef provided externally. +// +// The following is an EBNF-style syntax for GraphDef objects. Note that +// Node objects are actually specified as tensorflow::NodeDef protocol buffers, +// which contain many other fields that are not (currently) validated. 
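+// For example, the grammar below accepts node inputs such as "add:1"
+// (output 1 of node "add") and "^init" (a control input).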
+// +// Graph = Node * +// Node = NodeName, Inputs +// Inputs = ( DataInput * ), ( ControlInput * ) +// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? +// ControlInput = "^", NodeName +// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * +Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ diff --git a/tensorflow/core/framework/kernel_def.proto b/tensorflow/core/framework/kernel_def.proto new file mode 100644 index 0000000000..db7856a156 --- /dev/null +++ b/tensorflow/core/framework/kernel_def.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/attr_value.proto"; + +message KernelDef { + // Must match the name of an Op. + string op = 1; + + // Type of device this kernel runs on. + string device_type = 2; + + message AttrConstraint { + // Name of an attr from the Op. + string name = 1; + + // A list of values that this kernel supports for this attr. + // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops. + AttrValue allowed_values = 2; + } + repeated AttrConstraint constraint = 3; + + // Names of the Op's input_/output_args that reside in host memory + // instead of device memory. + repeated string host_memory_arg = 4; + + // This allows experimental kernels to be registered for an op that + // won't be used unless the user specifies a "_kernel" attr with + // value matching this. + string label = 5; +} diff --git a/tensorflow/core/framework/kernel_def_builder.cc b/tensorflow/core/framework/kernel_def_builder.cc new file mode 100644 index 0000000000..8fba883a16 --- /dev/null +++ b/tensorflow/core/framework/kernel_def_builder.cc @@ -0,0 +1,47 @@ +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { + +KernelDefBuilder::KernelDefBuilder(const char* op_name) { + kernel_def_ = new KernelDef; + kernel_def_->set_op(op_name); +} + +KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) { + kernel_def_->set_device_type(device_type); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::TypeConstraint( + const char* attr_name, gtl::ArraySlice allowed) { + auto* constraint = kernel_def_->add_constraint(); + constraint->set_name(attr_name); + auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); + for (DataType dt : allowed) { + allowed_values->add_type(dt); + } + return *this; +} + +KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name, + DataType allowed) { + auto* constraint = kernel_def_->add_constraint(); + constraint->set_name(attr_name); + constraint->mutable_allowed_values()->mutable_list()->add_type(allowed); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) { + kernel_def_->add_host_memory_arg(arg_name); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::Label(const char* label) { + CHECK_EQ(kernel_def_->label(), "") + << "Trying to set a kernel's label a second time: '" << label + << "' in: " << kernel_def_->ShortDebugString(); + kernel_def_->set_label(label); + return *this; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h new file mode 100644 index 0000000000..0c14d1e006 --- /dev/null +++ b/tensorflow/core/framework/kernel_def_builder.h @@ -0,0 +1,77 @@ +#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ +#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ + +#include 
"tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Builder class passed to the REGISTER_KERNEL_BUILDER() macro. +class KernelDefBuilder { + public: + // Starts with just the name field set. + // Caller MUST call Build() and take ownership of the result. + explicit KernelDefBuilder(const char* op_name); + + ~KernelDefBuilder() { + DCHECK(kernel_def_ == nullptr) << "Did not call Build()"; + } + + // Required: specify the type of device this kernel supports. + // Returns *this. + KernelDefBuilder& Device(const char* device_type); + // KernelDefBuilder& Device(DeviceType device_type); + + // Specify that this kernel supports a limited set of values for a + // particular type or list(type) attr (a further restriction than + // what the Op allows). + // Returns *this. + KernelDefBuilder& TypeConstraint(const char* attr_name, + gtl::ArraySlice allowed); + + // Like TypeConstraint but supports just a single type. + KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed); + + // Like TypeConstraint, but (a) gets the type from a template parameter + // and (b) only supports a constraint to a single type. + template + KernelDefBuilder& TypeConstraint(const char* attr_name); + // TODO(josh11b): Support other types of attr constraints as needed. + + // Specify that this kernel requires/provides an input/output arg + // in host memory (instead of the default, device memory). + // Returns *this. + KernelDefBuilder& HostMemory(const char* arg_name); + + // Specify that this kernel requires a particular value for the + // "_kernel" attr. May only be specified once. Returns *this. + KernelDefBuilder& Label(const char* label); + + // Returns a pointer to a KernelDef with fields set based on the + // above calls to this instance. + // Caller takes ownership of the result. 
+ const KernelDef* Build() { + KernelDef* r = kernel_def_; + kernel_def_ = nullptr; + return r; + } + + private: + KernelDef* kernel_def_; + + TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder); +}; + +// IMPLEMENTATION + +template +inline KernelDefBuilder& KernelDefBuilder::TypeConstraint( + const char* attr_name) { + return this->TypeConstraint(attr_name, DataTypeToEnum::v()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ diff --git a/tensorflow/core/framework/kernel_def_builder_test.cc b/tensorflow/core/framework/kernel_def_builder_test.cc new file mode 100644 index 0000000000..eba7144b59 --- /dev/null +++ b/tensorflow/core/framework/kernel_def_builder_test.cc @@ -0,0 +1,76 @@ +#include "tensorflow/core/framework/kernel_def_builder.h" + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include + +namespace tensorflow { +namespace { + +TEST(KernelDefBuilderTest, Basic) { + const KernelDef* def = KernelDefBuilder("A").Device(DEVICE_CPU).Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString("op: 'A' device_type: 'CPU'", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +TEST(KernelDefBuilderTest, TypeConstraint) { + const KernelDef* def = KernelDefBuilder("B") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString(R"proto( + op: 'B' device_type: 'GPU' + constraint { name: 'T' allowed_values { list { type: DT_FLOAT } } } )proto", + &expected); + + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; + + def = KernelDefBuilder("C") + .Device(DEVICE_GPU) + .TypeConstraint("U") + .TypeConstraint("V") + .Build(); + + protobuf::TextFormat::ParseFromString(R"proto( + op: 'C' device_type: 'GPU' + constraint { name: 'U' allowed_values { list { type: DT_INT32 } } } + constraint { name: 'V' allowed_values { list { type: DT_BOOL } } } )proto", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; + + def = KernelDefBuilder("D") + .Device(DEVICE_CPU) + .TypeConstraint("W", {DT_DOUBLE, DT_STRING}) + .Build(); + protobuf::TextFormat::ParseFromString(R"proto( + op: 'D' device_type: 'CPU' + constraint { name: 'W' + allowed_values { list { type: [DT_DOUBLE, DT_STRING] } } } )proto", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +TEST(KernelDefBuilderTest, HostMemory) { + const KernelDef* def = KernelDefBuilder("E") + .Device(DEVICE_GPU) + .HostMemory("in") + .HostMemory("out") + .Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString( + "op: 'E' device_type: 'GPU' " + "host_memory_arg: ['in', 'out']", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc new file mode 100644 index 0000000000..c660b84aa0 --- /dev/null +++ b/tensorflow/core/framework/lookup_interface.cc @@ -0,0 +1,45 @@ +#include "tensorflow/core/framework/lookup_interface.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace lookup { + +Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key, + const Tensor& value) { + if (key.dtype() != key_dtype()) { + return errors::InvalidArgument("Key must be type ", key_dtype(), + " but got ", key.dtype()); + } + if (value.dtype() != value_dtype()) { + return 
errors::InvalidArgument("Value must be type ", value_dtype(), + " but got ", value.dtype()); + } + if (key.NumElements() != value.NumElements()) { + return errors::InvalidArgument("Number of elements of key(", + key.NumElements(), ") and value(", + value.NumElements(), ") are different."); + } + if (!key.shape().IsSameSize(value.shape())) { + return errors::InvalidArgument("key and value have different shapes."); + } + return Status::OK(); +} + +Status LookupInterface::CheckFindArguments(const Tensor& key, + const Tensor& value, + const Tensor& default_value) { + TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value)); + + if (default_value.dtype() != value_dtype()) { + return errors::InvalidArgument("Default value must be type ", value_dtype(), + " but got ", default_value.dtype()); + } + if (!TensorShapeUtils::IsScalar(default_value.shape())) { + return errors::InvalidArgument("Default values must be scalar."); + } + return Status::OK(); +} + +} // namespace lookup +} // namespace tensorflow diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h new file mode 100644 index 0000000000..d4036d2019 --- /dev/null +++ b/tensorflow/core/framework/lookup_interface.h @@ -0,0 +1,65 @@ +#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ +#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace lookup { + +// Lookup interface for batch lookups used by table lookup ops. +class LookupInterface : public ResourceBase { + public: + // Performs batch lookups, for every element in the key tensor, Find returns + // the corresponding value into the values tensor. + // If an element is not present in the table, the given default value is used. + + // For tables that require initialization, Find is available once the table + // is marked as initialized. + + // Returns the following statuses: + // - OK: when the find finishes successfully. + // - FailedPrecondition: if the table is not initialized. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - In addition, other implementations may provide another non-OK status + // specific to their failure modes. + virtual Status Find(const Tensor& keys, Tensor* values, + const Tensor& default_value) = 0; + + // Returns the number of elements in the table. + virtual size_t size() const = 0; + + // Returns the data type of the key. + virtual DataType key_dtype() const = 0; + + // Returns the data type of the value. + virtual DataType value_dtype() const = 0; + + string DebugString() override { return "A lookup table"; } + + protected: + virtual ~LookupInterface() = default; + + // Check format of the key and value tensors. + // Returns OK if all the following requirements are satisfied, otherwise it + // returns InvalidArgument: + // - DataType of the tensor key equals to the table key_dtype + // - DataType of the test value equals to the table value_dtype + // - key and value have the same size and shape + Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values); + + // Check the arguments of a find operation. 
Returns OK if all the following + // requirements are satisfied, otherwise it returns InvalidArgument: + // - All requirements of CheckKeyAndValueTensors + // - default_value type equals to the table value_dtype + // - default_value is scalar + Status CheckFindArguments(const Tensor& keys, const Tensor& values, + const Tensor& default_value); +}; + +} // namespace lookup +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc new file mode 100644 index 0000000000..12757f153a --- /dev/null +++ b/tensorflow/core/framework/node_def_builder.cc @@ -0,0 +1,194 @@ +#include "tensorflow/core/framework/node_def_builder.h" + +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name, + const OpRegistryInterface* op_registry) { + node_def_.set_name(name); + Status status; + op_def_ = op_registry->LookUp(op_name, &status); + if (op_def_ == nullptr) { + errors_.push_back(status.error_message()); + inputs_specified_ = 0; + } else { + Initialize(); + } +} + +NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def) + : op_def_(op_def) { + node_def_.set_name(name); + Initialize(); +} + +void NodeDefBuilder::Initialize() { + inputs_specified_ = 0; + node_def_.set_op(op_def_->name()); +} + +const OpDef::ArgDef* NodeDefBuilder::NextArgDef() { + if (!NextArgAvailable()) return nullptr; + return &op_def_->input_arg(inputs_specified_++); +} + +bool NodeDefBuilder::NextArgAvailable() { + if (op_def_ == nullptr) { + return false; + } else if (inputs_specified_ >= op_def_->input_arg_size()) { + errors_.push_back(strings::StrCat("More Input() calls than the ", + op_def_->input_arg_size(), + " input_args")); + return false; + } + return true; +} + +NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) { + if (NextArgAvailable()) { + Status status = + fake_input(*op_def_, inputs_specified_, node_def_, this); + if (!status.ok()) errors_.push_back(status.error_message()); + } + return *this; +} + +void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, + const string& src_node, int src_index, + DataType dt) { + AddInput(src_node, src_index); + + if (!input_arg->number_attr().empty() || + !input_arg->type_list_attr().empty()) { + errors_.push_back(strings::StrCat("Single tensor passed to '", + input_arg->name(), "', expected list")); + return; + } + + if (input_arg->type() != DT_INVALID) { + const DataType expected = MaybeAddRef(input_arg, input_arg->type()); + VerifyInputType(input_arg, expected, dt); + } else { + VerifyInputRef(input_arg, dt); + Attr(input_arg->type_attr(), BaseType(dt)); + } +} + +void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg, + gtl::ArraySlice src_list) { + for (const auto& node_out : src_list) { + AddInput(node_out.node, node_out.index); + } + + if (!input_arg->number_attr().empty()) { + Attr(input_arg->number_attr(), static_cast(src_list.size())); + if (input_arg->type() != DT_INVALID) { + const DataType expected = MaybeAddRef(input_arg, input_arg->type()); + for (const auto& node_out : src_list) { + VerifyInputType(input_arg, expected, node_out.data_type); + } + } else if (!src_list.empty()) { + const DataType base = BaseType(src_list[0].data_type); + Attr(input_arg->type_attr(), base); + const DataType expected = 
MaybeAddRef(input_arg, base); + for (const auto& node_out : src_list) { + VerifyInputType(input_arg, expected, node_out.data_type); + } + } + } else if (!input_arg->type_list_attr().empty()) { + DataTypeVector type_vec; + type_vec.reserve(src_list.size()); + for (const auto& node_out : src_list) { + const DataType dt = node_out.data_type; + VerifyInputRef(input_arg, dt); + type_vec.push_back(BaseType(dt)); + } + Attr(input_arg->type_list_attr(), type_vec); + } else { + errors_.push_back(strings::StrCat("List provided to input '", + input_arg->name(), + "' when single Tensor expected")); + } +} + +void NodeDefBuilder::AddInput(const string& src_node, int src_index) { + if (src_node.empty()) { + errors_.push_back("Empty input node name"); + } else if (src_node[0] == '^') { + errors_.push_back( + strings::StrCat("Non-control input starting with ^: ", src_node)); + } else if (src_index > 0) { + node_def_.add_input(strings::StrCat(src_node, ":", src_index)); + } else { + node_def_.add_input(src_node); + } +} + +void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg, + DataType expected, DataType dt) { + if (!TypesCompatible(expected, dt)) { + errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ", + DataTypeString(dt), " expected ", + DataTypeString(expected))); + } +} + +void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg, + DataType dt) { + if (input_arg->is_ref() && !IsRefType(dt)) { + errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ", + DataTypeString(dt), + " expected ref type")); + } +} + +Status NodeDefBuilder::Finalize(NodeDef* node_def) const { + const std::vector* errors_ptr = &errors_; + std::vector errors_storage; + if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) { + // Since this is a const method, to add an error, we have to make + // a copy of the existing errors. + errors_storage = errors_; + errors_storage.push_back( + strings::StrCat(inputs_specified_, " inputs specified of ", + op_def_->input_arg_size(), " inputs in Op")); + errors_ptr = &errors_storage; + } + + if (!errors_ptr->empty()) { + if (errors_ptr->size() == 1) { + if (op_def_ == nullptr) { + return errors::InvalidArgument((*errors_ptr)[0], + " while building NodeDef '", + node_def_.name(), "'"); + } + return errors::InvalidArgument( + (*errors_ptr)[0], " while building NodeDef '", node_def_.name(), + "' using ", SummarizeOpDef(*op_def_)); + } else { + return errors::InvalidArgument( + errors_ptr->size(), " errors while building NodeDef '", + node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n", + str_util::Join(*errors_ptr, "\n")); + } + } else { + NodeDef node_def_backup; + if (node_def == nullptr) node_def = &node_def_backup; + *node_def = node_def_; + + // Add control inputs after the regular inputs. + for (const auto& control_input : control_inputs_) { + node_def->add_input(strings::StrCat("^", control_input)); + } + + // Add default values for unspecified attrs. 
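+    // (AddDefaultsToNodeDef() lives in node_def_util.cc; it copies any attr
+    // defaults declared in the OpDef that were not set explicitly above, so
+    // an op declared with, e.g., "T: type = DT_BOOL" still yields a fully
+    // specified NodeDef.)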
+ AddDefaultsToNodeDef(*op_def_, node_def); + + return Status::OK(); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h new file mode 100644 index 0000000000..706f072608 --- /dev/null +++ b/tensorflow/core/framework/node_def_builder.h @@ -0,0 +1,176 @@ +#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ +#define TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ + +#include +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +class NodeDefBuilder; +typedef std::function FakeInputFunctor; + +// This is a helper for creating a NodeDef. Automatically sets attrs +// that can be inferred from the inputs, and uses default values +// (where they exist) for unspecified attrs. Example usage: +// +// NodeDef node_def; +// Status status = NodeDefBuilder(node_name, op_name) +// .Input(...) +// .Attr(...) +// .Finalize(&node_def); +// if (!status.ok()) return status; +// // Use node_def here. +class NodeDefBuilder { + public: + // To specify an output to be consumed by one of the Input() methods below. + struct NodeOut { + NodeOut(const string& n, int i, DataType dt) + : node(n), index(i), data_type(dt) {} + NodeOut() {} // uninitialized, call Reset() before use. + void Reset(const string& n, int i, DataType dt) { + node = n; + index = i; + data_type = dt; + } + string node; + int index; + DataType data_type; + }; + + // Specify the name and the Op (either via an OpDef or the name of + // the Op plus a registry) for the NodeDef. Other fields are + // specified by calling the methods below. + // REQUIRES: The OpDef must satisfy ValidateOpDef(). + NodeDefBuilder(const string& name, const string& op_name, + const OpRegistryInterface* op_registry = OpRegistry::Global()); + // REQUIRES: in addition, *op_def must outlive *this. + NodeDefBuilder(const string& name, const OpDef* op_def); + + // You must call one Input() function per input_arg in the Op, + // *and in the same order as the input_args appear in the OpDef.* + + // For inputs that take a single tensor. + NodeDefBuilder& Input(const string& src_node, int src_index, DataType dt) { + const OpDef::ArgDef* arg = NextArgDef(); + if (arg != nullptr) SingleInput(arg, src_node, src_index, dt); + return *this; + } + NodeDefBuilder& Input(const NodeOut& src) { + Input(src.node, src.index, src.data_type); + return *this; + } + + // For inputs that take a list of tensors. + NodeDefBuilder& Input(gtl::ArraySlice src_list) { + const OpDef::ArgDef* arg = NextArgDef(); + if (arg != nullptr) ListInput(arg, src_list); + return *this; + } + + // To create inputs in tests, see fake_input.h. + NodeDefBuilder& Input(FakeInputFunctor fake_input); + + // Specify that this node must only run after src_node. + NodeDefBuilder& ControlInput(const string& src_node) { + control_inputs_.push_back(src_node); + return *this; + } + + // Constrains what devices this node may be scheduled on. + NodeDefBuilder& Device(const string& device_spec) { + node_def_.set_device(device_spec); + return *this; + } + + // Sets the attr, if not already set. 
If already set with a different + // value, an error will be returned from Finalize(). + template + NodeDefBuilder& Attr(const string& attr_name, T&& value); + // Note: overload needed to allow {...} expressions for value. + template + NodeDefBuilder& Attr(const string& attr_name, + std::initializer_list value) { + Attr>(attr_name, std::move(value)); + return *this; + } + + // Finish building the NodeDef, returning any errors or setting + // *node_def if none. + // WARNING: Not all problems are detected! The resulting NodeDef may + // not be valid! Call ValidateNodeDef() from node_def_utils to be sure. + Status Finalize(NodeDef* node_def) const; + + // Accessor for the OpDef set in the constructor. + const OpDef& op_def() const { return *op_def_; } + + private: + // Called in the constructors. + void Initialize(); + + // Get the current ArgDef and advance to the next one. Returns nullptr + // if no more inputs are available. + const OpDef::ArgDef* NextArgDef(); + + // Returns true if there is still an input_arg available in *op_def_, + // otherwise adds to error_ and returns false. + bool NextArgAvailable(); + + // These do the main work of the Input() methods. + void SingleInput(const OpDef::ArgDef* input_arg, const string& src_node, + int src_index, DataType dt); + void ListInput(const OpDef::ArgDef* input_arg, + gtl::ArraySlice src_list); + + // Add "src_node:src_index" to the list of inputs in the node_def_. + void AddInput(const string& src_node, int src_index); + + // Generate an error if you can't pass dt when expected is expected. + void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected, + DataType dt); + + // If input_arg->is_ref() is true, generate an error if dt is not a ref. + void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt); + + // Makes dt a ref type if that is what the input_arg specifies. + DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) { + return input_arg->is_ref() ? MakeRefType(dt) : dt; + } + + const OpDef* op_def_; + NodeDef node_def_; + int inputs_specified_; + std::vector control_inputs_; + std::vector errors_; +}; + +// IMPLEMENTATION ------------------------------------------------------------- + +template +NodeDefBuilder& NodeDefBuilder::Attr(const string& attr_name, T&& value) { + const AttrValue* found = AttrSlice(node_def_).Find(attr_name); + if (found == nullptr) { + AddNodeAttr(attr_name, std::forward(value), &node_def_); + } else { + AttrValue attr_value; + SetAttrValue(std::forward(value), &attr_value); + if (!AreAttrValuesEqual(*found, attr_value)) { + errors_.push_back(strings::StrCat( + "Inconsistent values for attr '", attr_name, "' ", + SummarizeAttrValue(*found), " vs. 
", SummarizeAttrValue(attr_value))); + } + } + return *this; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc new file mode 100644 index 0000000000..6fd4a8d1ed --- /dev/null +++ b/tensorflow/core/framework/node_def_builder_test.cc @@ -0,0 +1,1036 @@ +#include "tensorflow/core/framework/node_def_builder.h" + +#include +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include + +namespace tensorflow { +namespace { + +class NodeDefBuilderTest : public ::testing::Test { + protected: + // Specify an OpDef via an OpDefBuilder. + void Op(const OpDefBuilder& op_def_builder) { + EXPECT_OK(op_def_builder.Finalize(&op_def_)); + } + + // Resets builder_ with a new NodeDefBuilder using the Op from the last call + // to Op() above. + NodeDefBuilder& Builder() { + EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()"; + builder_.reset(new NodeDefBuilder("n", &op_def_)); + return *builder_; + } + + // Calls Finalize() and verifies it returns success and the result matches + // expectations. + void ExpectSuccess(const NodeDefBuilder& builder, + DataTypeSlice expected_in_types, + DataTypeSlice expected_out_types, StringPiece proto) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + EXPECT_OK(status); + if (!status.ok()) return; + NodeDef expected; + protobuf::TextFormat::ParseFromString(strings::StrCat("name: 'n' ", proto), + &expected); + EXPECT_EQ(node_def.DebugString(), expected.DebugString()); + + DataTypeVector in_types, out_types; + status = + InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types); + EXPECT_OK(status); + if (!status.ok()) return; + EXPECT_EQ(DataTypeSliceString(expected_in_types), + DataTypeVectorString(in_types)); + EXPECT_EQ(DataTypeSliceString(expected_out_types), + DataTypeVectorString(out_types)); + + status = ValidateNodeDef(node_def, op_def_); + EXPECT_OK(status); + } + + // Calls Finalize() and verifies it returns an error. + // Each message must appear as a substring of the error. + void ExpectFailures(const NodeDefBuilder& builder, + const std::vector& messages) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); + if (status.ok()) return; + for (const string& message : messages) { + EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) + << status << ", " << message; + } + } + + // Calls Finalize() and verifies it returns an error. + // Message must appear as a substring of the error. + void ExpectFailure(const NodeDefBuilder& builder, const string& message) { + ExpectFailures(builder, {message}); + } + + // Like ExpectFailure(), except that the error can come from + // ValidateNodeDef(). 
+ void ExpectInvalid(const NodeDefBuilder& builder, const string& message) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + if (status.ok()) { + status = ValidateNodeDef(node_def, op_def_); + } + EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); + if (status.ok()) return; + EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) + << "Actual error: " << status.error_message() + << "\nDoes not contain: " << message; + } + + OpDef op_def_; + std::unique_ptr builder_; +}; + +TEST_F(NodeDefBuilderTest, Simple) { + Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float")); + + ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "x" )proto"); + + // Port != 0 + ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "y:2" )proto"); + + // FakeInput + ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "a" )proto"); + + // Ref input + ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32}, + {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto"); + + // ControlInput + ExpectSuccess( + Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"), + {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: ["a", "^x", "^y"] )proto"); + + // Device + ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32}, + {DT_FLOAT}, R"proto( + op: "Simple" input: "a" device: "ddd" )proto"); + + // Extra input + ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32), + "More Input() calls than the 1 input_args while building " + "NodeDef 'n' using Op " + "out:float>"); + + // Missing input + ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while"); + + { // Finalize() twice. + NodeDefBuilder& builder = Builder(); + builder.Input(FakeInput()).Finalize(nullptr); // First call to Finalize() + // ExpectSuccess() also calls Finalize(). + ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + } + + { // Input() after Finalize() + NodeDefBuilder& builder = Builder(); + // Calling Finalize() before enough inputs -> error. + ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while"); + builder.Input(FakeInput()); + // Calling Finalize() with enough inputs -> success + ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + // Calling Finalize() with too many inputs -> error. 
+ builder.Input(FakeInput(DT_INT32)); + ExpectFailure(builder, "More Input() calls than the 1 input_args while"); + } + + // Wrong input type + ExpectFailure(Builder().Input("x", 0, DT_FLOAT), + "Input 'a' passed float expected int32 "); + + ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF), + "Input 'a' passed float_ref expected int32 "); + + // List input + ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)), + "List provided to input 'a' when single Tensor expected while"); + + ExpectFailure(Builder().Input(FakeInput(3)), + "List provided to input 'a' when single Tensor expected while"); + + // Bad ControlInput + ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"), + "Control input '^z:2' must not have ':' in NodeDef:"); + + // Bad input name + ExpectFailure(Builder().Input("", 0, DT_INT32), + "Empty input node name while"); + + ExpectFailure(Builder().Input("^x", 0, DT_INT32), + "Non-control input starting with ^: ^x while"); +} + +TEST_F(NodeDefBuilderTest, OpDoesNotExist) { + NodeDefBuilder builder("n", "Op Does Not Exist"); + builder.Input(FakeInput()) + .Input(FakeInput(12)) + .ControlInput("y") + .Attr("foo", 12) + .Device("device"); + ExpectFailure( + builder, + "Op type not registered 'Op Does Not Exist' while building NodeDef 'n'"); +} + +TEST_F(NodeDefBuilderTest, Polymorphic) { + Op(OpDefBuilder("Polymorphic") + .Input("v: T") + .Output("out: T") + .Attr("T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32}, + R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT}, + R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Redundant Attr() + ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL), + {DT_BOOL}, {DT_BOOL}, R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_BOOL } } )proto"); + + // Conficting Attr() + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING), + "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); + + ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)), + "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while"); + + ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)), + "Inconsistent values for attr 'T' 12 vs. DT_BOOL while"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicOut) { + Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type")); + + ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Redundant attr + ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {}, + {DT_FLOAT}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Conflicting attr + ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT), + "Inconsistent values for attr 'T' DT_BOOL vs. 
DT_FLOAT while"); + + // Missing attr + ExpectInvalid(Builder(), "NodeDef missing attr 'T' from"); + + // Attr has the wrong type + ExpectInvalid(Builder().Attr("T", {DT_INT32, DT_BOOL}), + "AttrValue had value with type list(type) when type expected"); + + ExpectInvalid(Builder().Attr("T", 12), + "AttrValue had value with type int when type expected"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) { + Op(OpDefBuilder("PolymorphicDefaultOut") + .Output("out: T") + .Attr("T: type = DT_STRING")); + + ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto( + op: "PolymorphicDefaultOut" + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto( + op: "PolymorphicDefaultOut" + attr { key: "T" value { type: DT_BOOL } } )proto"); +} + +TEST_F(NodeDefBuilderTest, Binary) { + Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr( + "T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)), + {DT_INT32, DT_INT32}, {DT_INT32}, R"proto( + op: "Binary" input: "a" input: "b" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()), + {DT_STRING, DT_STRING}, {DT_STRING}, R"proto( + op: "Binary" input: "a" input: "b" + attr { key: "T" value { type: DT_STRING } } )proto"); + + // Type mismatch + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)), + "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, Restrict) { + Op(OpDefBuilder("Restrict") + .Input("a: T") + .Output("out: T") + .Attr("T: {string, bool}")); + ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING}, + R"proto( + op: "Restrict" input: "a" + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput(DT_INT32)), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, TypeList) { + Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)")); + + ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})), + {DT_STRING, DT_INT32}, {}, R"proto( + op: "TypeList" input: ["a", "a:1"] + attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } } + )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)), + {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "TypeList" input: ["a", "a:1", "a:2"] + attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } } + )proto"); + + ExpectInvalid(Builder().Input(FakeInput(0)), + "Length for attr 'T' of 0 must be at least minimum 1"); + + ExpectInvalid(Builder().Input(FakeInput({})), + "Length for attr 'T' of 0 must be at least minimum 1"); + + ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)), + "Single tensor passed to 'a', expected list while"); + + ExpectFailures(Builder().Input(FakeInput()), + {"2 errors while building NodeDef", + "Could not infer list of types for input 'a': " + "No attr named 'T' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); +} + +TEST_F(NodeDefBuilderTest, TypeListNoMin) { + Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + 
ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto( + op: "TypeListNoMin" input: "a" + attr { key: "T" value { list { type: DT_BOOL } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, TypeListTwice) { + Op(OpDefBuilder("TypeListTwice") + .Input("a: T") + .Input("b: T") + .Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder() + .Input(FakeInput({DT_INT32, DT_BOOL})) + .Input(FakeInput({DT_INT32, DT_BOOL})), + {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( + op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] + attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); + + ExpectSuccess( + Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()), + {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( + op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] + attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {}, + R"proto( + op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, + R"proto( + op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); + + ExpectFailure(Builder() + .Input(FakeInput({DT_INT32, DT_BOOL})) + .Input(FakeInput({DT_INT32, DT_STRING})), + "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. " + "[DT_INT32, DT_STRING] while"); +} + +TEST_F(NodeDefBuilderTest, OutTypeList) { + Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { type: DT_FLOAT } } } )proto"); + + ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {}, + {DT_STRING, DT_BOOL}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); + + ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { } } } )proto"); + + ExpectInvalid(Builder().Attr("T", DT_FLOAT), + "AttrValue had value with type type when list(type) expected"); +} + +TEST_F(NodeDefBuilderTest, TypeListRestrict) { + Op(OpDefBuilder("TypeListRestrict") + .Input("a: T") + .Attr("T: list({string, bool}) >= 0")); + + ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})), + {DT_STRING, DT_BOOL}, {}, R"proto( + op: "TypeListRestrict" input: ["a", "a:1"] + attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, OutTypeListRestrict) { + Op(OpDefBuilder("OutTypeListRestrict") + .Output("out: t") + .Attr("t: list({string, bool}) >= 0")); + + ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {}, + {DT_BOOL, DT_STRING}, R"proto( + op: "OutTypeListRestrict" + attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto"); + + ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}), + "Value for attr 't' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, Attr) { + Op(OpDefBuilder("Attr").Attr("a: int")); + + ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( + op: "Attr" attr { key: "a" value { i: 12 } } )proto"); + + // Attr has 
wrong type + ExpectInvalid(Builder().Attr("a", "bad"), + "AttrValue had value with type string when int expected"); + + ExpectInvalid(Builder().Attr("a", {12}), + "AttrValue had value with type list(int) when int expected"); + + // Missing attr + ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<"); + + // Wrong attr + ExpectInvalid(Builder().Attr("b", 12), + "NodeDef mentions attr 'b' not in Op<"); + + // Extra attr + ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12), + "NodeDef mentions attr 'extra' not in Op<"); +} + +TEST_F(NodeDefBuilderTest, AttrFloat) { + Op(OpDefBuilder("AttrFloat").Attr("a: float")); + + ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto( + op: "AttrFloat" attr { key: "a" value { f: 1.2 } } + )proto"); + + ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto( + op: "AttrFloat" attr { key: "a" value { f: 1.2 } } + )proto"); + + // Won't automatically cast int to float + ExpectInvalid(Builder().Attr("a", 12), + "AttrValue had value with type int when float expected"); +} + +TEST_F(NodeDefBuilderTest, AttrBoolList) { + Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)")); + + ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto( + op: "AttrBoolList" + attr { key: "a" value { list { b: [true, false, true] } } } + )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector()), {}, {}, R"proto( + op: "AttrBoolList" attr { key: "a" value { list { } } } + )proto"); + + // Won't cast int -> bool. + ExpectInvalid(Builder().Attr("a", {0}), + "AttrValue had value with type list(int) when list(bool) " + "expected"); +} + +TEST_F(NodeDefBuilderTest, AttrMin) { + Op(OpDefBuilder("AttrMin").Attr("a: int >= 5")); + + ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( + op: "AttrMin" attr { key: "a" value { i: 12 } } )proto"); + + ExpectInvalid(Builder().Attr("a", 2), + "Value for attr 'a' of 2 must be at least minimum 5"); +} + +TEST_F(NodeDefBuilderTest, AttrListMin) { + Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2")); + + ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto( + op: "AttrListMin" + attr { key: "a" value { list { i: [1, 2] } } } )proto"); + + ExpectInvalid(Builder().Attr("a", {17}), + "Length for attr 'a' of 1 must be at least minimum 2"); +} + +TEST_F(NodeDefBuilderTest, AttrEnum) { + Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}")); + + ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto( + op: "AttrEnum" + attr { key: "a" value { s: "oranges" } } )proto"); + + ExpectInvalid( + Builder().Attr("a", "invalid"), + "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " + "\"apples\", \"oranges\""); +} + +TEST_F(NodeDefBuilderTest, AttrEnumList) { + Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})")); + + ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto( + op: "AttrEnumList" + attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto"); + + ExpectInvalid( + Builder().Attr("a", {"apples", "invalid", "oranges"}), + "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " + "\"apples\", \"oranges\""); +} + +TEST_F(NodeDefBuilderTest, AttrShape) { + Op(OpDefBuilder("AttrShape").Attr("a: shape")); + + ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { dim { size: 5 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" 
value { shape { + dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {}, + R"proto( + op: "AttrShape" + attr { key: "a" value { shape { + dim { size: 3 } dim { size: 2 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrDefault) { + Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrDefault" + attr { key: "a" value { s: "banana" } } )proto"); + + ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto( + op: "AttrDefault" + attr { key: "a" value { s: "kiwi" } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrManyDefault) { + Op(OpDefBuilder("AttrManyDefault") + .Attr("a: string = 'banana'") + .Attr("b: string = 'kiwi'")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrManyDefault" + attr { key: "a" value { s: "banana" } } + attr { key: "b" value { s: "kiwi" } } )proto"); + + Op(OpDefBuilder("AttrManyDefaultWithMandatory") + .Attr("a: string = 'banana'") + .Attr("b: string = 'kiwi'") + .Attr("c: string")); + + ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto( + op: "AttrManyDefaultWithMandatory" + attr { key: "c" value { s: "strawberry" } } + attr { key: "a" value { s: "banana" } } + attr { key: "b" value { s: "kiwi" } } )proto"); + + Op(OpDefBuilder("AttrManyDefaultAndInferred") + .Input("input: T") + .Attr("T: {float, double}") + .Attr("a: string") + .Attr("b: list(string) >= 1") + .Attr("c: bool = true") + .Attr("d: float = 0.3") + .Attr("e: string") + .Attr("f: float = 0.25")); + + ExpectSuccess(Builder() + .Input(FakeInput(DT_FLOAT)) + .Attr("a", "foo") + .Attr("e", "foo") + .Attr("b", std::vector({"bar", "baz"})) + .Attr("f", 1.0f), + {DT_FLOAT}, {}, R"proto( + op: "AttrManyDefaultAndInferred" + input: "a" + attr { key: "T" value { type: DT_FLOAT } } + attr { key: "a" value { s: "foo" } } + attr { key: "e" value { s: "foo" } } + attr { key: "b" value { list { s: "bar" s: "baz" } } } + attr { key: "f" value { f: 1.0 } } + attr { key: "c" value { b: true } } + attr { key: "d" value { f: 0.3 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrListDefault) { + Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { i: [5, 15] } } } )proto"); + + ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { i: 3 } } } )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector()), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) { + Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { } } } )proto"); + + ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { i: 3 } } } )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector()), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NIntsIn) { + Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {}, + R"proto( + op: "NIntsIn" 
input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "NIntsIn" + input: ["a", "a:1", "a:2", "a:3", "a:4"] + attr { key: "N" value { i: 5 } } )proto"); + + ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)), + {"2 errors while building NodeDef", + "Input 'a' passed string expected int32"}); + + ExpectInvalid(Builder().Input(FakeInput(1)), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectFailures( + Builder().Input(FakeInput(DT_INT32)), + {"2 errors while building NodeDef", + "Could not infer length of input 'a': No attr named 'N' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); + + ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), + "Input 'a' passed string expected int32 while"); + + ExpectFailures( + Builder().Input(FakeInput()), + {"2 errors while building NodeDef", + "Could not infer length of input 'a': No attr named 'N' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicIn) { + Op(OpDefBuilder("NPolymorphicIn") + .Input("a: N*T") + .Attr("T: type") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32}, + {}, R"proto( + op: "NPolymorphicIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), + {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( + op: "NPolymorphicIn" + input: ["a", "a:1", "a:2"] + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectFailures( + Builder().Input(FakeInput(2)), + {"2 errors while building NodeDef", + "Could not infer type for input 'a': No attr named 'T' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); + + ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})), + "Input 'a' passed string expected int32 while"); + + ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), + "Input 'a' passed string expected int32 while"); + + ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectFailure(Builder().Input("in", 0, DT_INT32), + "Single tensor passed to 'a', expected list while"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) { + Op(OpDefBuilder("NPolymorphicRestrictIn") + .Input("a: N*T") + .Attr("T: {string, bool}") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {}, + R"proto( + op: "NPolymorphicRestrictIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), + {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( + op: "NPolymorphicRestrictIn" + input: ["a", "a:1", "a:2"] + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, NInTwice) { + Op(OpDefBuilder("NInTwice") + .Input("a: N*int32") + .Input("b: N*string") + .Attr("N: int >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)), + {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto( + op: "NInTwice" + input: ["a", "a:1", "b", 
"b:1"] + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, + R"proto( + op: "NInTwice" attr { key: "N" value { i: 0 } } )proto"); + + ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); +} + +TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) { + Op(OpDefBuilder("NInPolymorphicTwice") + .Input("a: N*T") + .Input("b: N*T") + .Attr("T: type") + .Attr("N: int >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "NInPolymorphicTwice" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); + + ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) { + Op(OpDefBuilder("NInTwoTypeVariables") + .Input("a: N*S") + .Input("b: N*T") + .Attr("S: type") + .Attr("T: type") + .Attr("N: int >= 0")); + + ExpectSuccess( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)), + {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "NInTwoTypeVariables" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "S" value { type: DT_INT32 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)), + {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "NInTwoTypeVariables" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "S" value { type: DT_INT32 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)), + "Inconsistent values for attr 'N' 3 vs. 
1 while"); +} + +TEST_F(NodeDefBuilderTest, InPolymorphicTwice) { + Op(OpDefBuilder("InPolymorphicTwice") + .Input("a: N*T") + .Input("b: M*T") + .Attr("T: type") + .Attr("N: int >= 0") + .Attr("M: int >= 0")); + + ExpectSuccess( + Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "InPolymorphicTwice" + input: ["a", "b", "b:1", "b:2"] + attr { key: "N" value { i: 1 } } + attr { key: "T" value { type: DT_INT32 } } + attr { key: "M" value { i: 3 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)), + {DT_BOOL}, {}, R"proto( + op: "InPolymorphicTwice" input: "a" + attr { key: "N" value { i: 1 } } + attr { key: "T" value { type: DT_BOOL } } + attr { key: "M" value { i: 0 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)), + {DT_BOOL}, {}, R"proto( + op: "InPolymorphicTwice" input: "b" + attr { key: "N" value { i: 0 } } + attr { key: "M" value { i: 1 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, NIntsOut) { + Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( + op: "NIntsOut" + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32}, + R"proto( + op: "NIntsOut" + attr { key: "N" value { i: 3 } } )proto"); + + ExpectInvalid(Builder().Attr("N", 1), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectInvalid(Builder().Attr("N", {3}), + "AttrValue had value with type list(int) when int expected"); + + ExpectInvalid(Builder(), "NodeDef missing attr 'N' from"); +} + +TEST_F(NodeDefBuilderTest, NIntsOutDefault) { + Op(OpDefBuilder("NIntsOutDefault") + .Output("a: N*int32") + .Attr("N: int >= 2 = 3")); + + ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto( + op: "NIntsOutDefault" + attr { key: "N" value { i: 3 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( + op: "NIntsOutDefault" + attr { key: "N" value { i: 2 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicOut) { + Op(OpDefBuilder("NPolymorphicOut") + .Output("a: N*T") + .Attr("T: type") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {}, + {DT_INT32, DT_INT32}, R"proto( + op: "NPolymorphicOut" + attr { key: "T" value { type: DT_INT32 } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {}, + {DT_STRING, DT_STRING, DT_STRING}, R"proto( + op: "NPolymorphicOut" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectInvalid(Builder().Attr("N", 3).Attr("T", {DT_STRING}), + "AttrValue had value with type list(type) when type expected"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) { + Op(OpDefBuilder("NPolymorphicOutDefault") + .Output("a: N*T") + .Attr("T: type = DT_BOOL") + .Attr("N: int >= 2 = 2")); + + ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto( + op: "NPolymorphicOutDefault" + attr { key: "T" value { type: DT_BOOL } } + attr { key: "N" value { i: 2 } } 
)proto"); + + ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL}, + R"proto( + op: "NPolymorphicOutDefault" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32}, + R"proto( + op: "NPolymorphicOutDefault" + attr { key: "T" value { type: DT_INT32 } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {}, + {DT_INT32, DT_INT32, DT_INT32}, R"proto( + op: "NPolymorphicOutDefault" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) { + Op(OpDefBuilder("NPolymorphicRestrictOut") + .Output("a: N*T") + .Attr("T: {string, bool}") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {}, + {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto( + op: "NPolymorphicRestrictOut" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, RefIn) { + Op(OpDefBuilder("RefIn").Input("a: Ref(int32)")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {}, + R"proto( + op: "RefIn" input: "a" )proto"); + + ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)), + "Input 'a' passed bool_ref expected int32_ref while"); + + ExpectFailure(Builder().Input(FakeInput(DT_INT32)), + "Input 'a' passed int32 expected int32_ref while"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicRefIn) { + Op(OpDefBuilder("PolymorphicRefIn").Input("a: Ref(T)").Attr("T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {}, + R"proto( + op: "PolymorphicRefIn" input: "a" + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)), + "Input 'a' passed bool expected ref type while"); +} + +TEST_F(NodeDefBuilderTest, RefOut) { + Op(OpDefBuilder("RefOut").Output("a: Ref(string)")); + + ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto( + op: "RefOut" )proto"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicRefOut) { + Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type")); + + ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto( + op: "PolymorphicRefOut" + attr { key: "t" value { type: DT_BOOL } } )proto"); +} + +TEST_F(NodeDefBuilderTest, SpecifyDevice) { + Op(OpDefBuilder("SpecifyDevice")); + + ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto( + op: "SpecifyDevice" device: "ADevice" )proto"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc new file mode 100644 index 0000000000..aefd416187 --- /dev/null +++ b/tensorflow/core/framework/node_def_util.cc @@ -0,0 +1,414 @@ +#include "tensorflow/core/framework/node_def_util.h" + +#include +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +string SummarizeNodeDef(const NodeDef& 
node_def) { + string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "["); + + // We sort the attrs so the output is deterministic. + std::vector attr_names; + attr_names.reserve(node_def.attr().size()); + for (const auto& attr : node_def.attr()) { + attr_names.push_back(attr.first); + } + std::sort(attr_names.begin(), attr_names.end()); + bool first = true; + for (const string& attr_name : attr_names) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + auto iter = node_def.attr().find(attr_name); + strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second)); + } + + // Consider the device to be a final attr with name "_device". + if (!node_def.device().empty()) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + strings::StrAppend(&ret, "_device=\"", node_def.device(), "\""); + } + strings::StrAppend(&ret, "]("); + + // Output inputs, including control inputs, verbatim. + first = true; + for (const string& input : node_def.input()) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + strings::StrAppend(&ret, input); + } + strings::StrAppend(&ret, ")"); + return ret; +} + +const AttrValue* AttrSlice::Find(const string& attr_name) const { + auto iter = attrs_->find(attr_name); + if (iter == attrs_->end()) return nullptr; + return &iter->second; +} + +Status AttrSlice::Find(const string& attr_name, + const AttrValue** attr_value) const { + *attr_value = Find(attr_name); + if (*attr_value != nullptr) { + return Status::OK(); + } + Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:"); + if (ndef_) { + s = AttachDef(s, *ndef_); + } + return s; +} + +// The ... is to allow the caller to inject some value validation code. Use +// just ; if no additional validation code is needed. +#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) 
\ + Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \ + TYPE* value) { \ + const AttrValue* attr_value; \ + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE)); \ + const auto& v = attr_value->FIELD(); \ + __VA_ARGS__; \ + *value = CAST; \ + return Status::OK(); \ + } \ + Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \ + std::vector* value) { \ + const AttrValue* attr_value; \ + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \ + for (const auto& v : attr_value->list().FIELD()) { \ + __VA_ARGS__; \ + value->APPEND_OP(CAST); \ + } \ + return Status::OK(); \ + } + +DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;) +DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;) +DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast(v), + if (static_cast(static_cast(v)) != v) { + return errors::InvalidArgument("Attr ", attr_name, + " has value ", v, + " out of range for an int32"); + }) +DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;) +// std::vector specialization does not have emplace_back until +// c++14, so we have to use push_back (see +// http://en.cppreference.com/w/cpp/container/vector/emplace_back) +DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;) +DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast(v), + ;) +DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;) +DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v), ;) +DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t; + if (!t.FromProto(v)) { + return errors::InvalidArgument( + "Attr ", attr_name, " has value ", v.ShortDebugString(), + " that can't be converted to a Tensor"); + }) + +#undef DEFINE_GET_ATTR + +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + DataTypeVector* value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)")); + for (const auto& v : attr_value->list().type()) { + value->push_back(static_cast(v)); + } + return Status::OK(); +} + +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + const TensorProto** value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor")); + *value = &attr_value->tensor(); + return Status::OK(); +} + +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + const NameAttrList** value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func")); + *value = &attr_value->func(); + return Status::OK(); +} + +namespace { // Helper for InOutTypesForNode(). + +Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def, + DataTypeVector* sig) { + const int original_size = sig->size(); + if (!arg_def.number_attr().empty()) { + // Same type repeated "repeats" times. 
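+    // (For example, an input declared as "a: N*int32" stores its length in
+    // the "N" attr; that attr supplies the repeat count read below.)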
+ int32 repeats = -1; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats)); + if (repeats < 0) { + return errors::InvalidArgument("Value for number_attr() ", repeats, + " < 0"); + } + + if (!arg_def.type_attr().empty()) { + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype)); + for (int i = 0; i < repeats; ++i) { + sig->push_back(dtype); + } + } else if (arg_def.type() != DT_INVALID) { + for (int i = 0; i < repeats; ++i) { + sig->push_back(arg_def.type()); + } + } else { + return errors::InvalidArgument("Missing type or type_attr field in ", + arg_def.ShortDebugString()); + } + } else if (!arg_def.type_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value)); + sig->push_back(attr_value->type()); + } else if (!arg_def.type_list_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value)); + for (int dtype : attr_value->list().type()) { + sig->push_back(static_cast<DataType>(dtype)); + } + } else if (arg_def.type() != DT_INVALID) { + sig->push_back(arg_def.type()); + } else { + return errors::InvalidArgument("No type fields in ", + arg_def.ShortDebugString()); + } + if (arg_def.is_ref()) { + // For all types that were added by this function call, make them refs. + for (size_t i = original_size; i < sig->size(); ++i) { + (*sig)[i] = MakeRefType((*sig)[i]); + } + } + return Status::OK(); +} + +} // namespace + +Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs, DataTypeVector* outputs) { + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); + } + for (const auto& arg : op_def.output_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs)); + } + return Status::OK(); +} + +Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { + if (node_def.op() != op_def.name()) { + return errors::InvalidArgument("NodeDef op '", node_def.op(), + "' does not match ", SummarizeOpDef(op_def), + "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + bool seen_control = false; + size_t num_inputs = 0; + // TODO(josh11b): Unify the input field validation. + for (const string& input : node_def.input()) { + if (StringPiece(input).starts_with("^")) { + seen_control = true; + if (input.find(':') != string::npos) { + return errors::InvalidArgument("Control input '", input, + "' must not have ':' in NodeDef: ", + SummarizeNodeDef(node_def)); + } + } else if (seen_control) { + return errors::InvalidArgument("Non-control input '", input, + "' after control input in NodeDef: ", + SummarizeNodeDef(node_def)); + } else { + ++num_inputs; + } + } + + std::unordered_map<string, const OpDef::AttrDef*> op_attrs; + for (const auto& attr : op_def.attr()) { + if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) { + return errors::InvalidArgument("OpDef has duplicate attr name '", + attr.name(), "': ", + SummarizeOpDef(op_def)); + } + } + for (const auto& attr : node_def.attr()) { + // Allow internal optional attributes with names starting with "_".
+ if (StringPiece(attr.first).starts_with("_")) { + continue; + } + auto iter = op_attrs.find(attr.first); + if (iter == op_attrs.end()) { + return errors::InvalidArgument("NodeDef mentions attr '", attr.first, + "' not in ", SummarizeOpDef(op_def), + "; NodeDef: ", SummarizeNodeDef(node_def)); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ", + SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def)); + // Keep track of which attr names have (not) been found in the NodeDef. + op_attrs.erase(iter); + } + + // Were all attrs in the OpDef found in the NodeDef? + if (!op_attrs.empty()) { + string attrs; + for (const auto& attr_pair : op_attrs) { + if (!attrs.empty()) strings::StrAppend(&attrs, "', '"); + strings::StrAppend(&attrs, attr_pair.first); + } + return errors::InvalidArgument("NodeDef missing attr", + op_attrs.size() == 1 ? " '" : "s '", attrs, + "' from ", SummarizeOpDef(op_def), + "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + // Validate the number of inputs. + DataTypeVector inputs, outputs; + TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs)); + + if (num_inputs != inputs.size()) { + return errors::InvalidArgument( + "NodeDef expected inputs '", DataTypeVectorString(inputs), + "' do not match ", num_inputs, " inputs specified; ", + SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + return Status::OK(); +} + +namespace { // Helpers for NameRangesForNode() + +Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def, + const OpDef& op_def, int* num) { + if (!arg_def.number_attr().empty()) { + // Same type repeated "num" times. + return GetNodeAttr(node_def, arg_def.number_attr(), num); + } else if (!arg_def.type_list_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value)); + *num = attr_value->list().type_size(); + } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) { + *num = 1; + } else { + return errors::InvalidArgument("Argument '", arg_def.name(), + "' incorrectly specified in op definition: ", + SummarizeOpDef(op_def)); + } + return Status::OK(); +} + +Status NameRangesHelper(const NodeDef& node_def, + const protobuf::RepeatedPtrField& args, + const OpDef& op_def, NameRangeMap* result) { + int start = 0; + int num; + for (const auto& arg : args) { + TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num)); + (*result)[arg.name()] = std::make_pair(start, start + num); + start += num; + } + return Status::OK(); +} + +} // namespace + +Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs) { + TF_RETURN_IF_ERROR( + NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs)); + return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs); +} + +void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { + for (const auto& attr_def : op_def.attr()) { + AttrSlice attrs(*node_def); + if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) { + AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def); + } + } +} + +namespace { + +static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); +static RE2* valid_data_input_pattern = + new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?"); +static RE2* valid_control_input_pattern = + new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); + +} // namespace + +Status 
ValidateOpInput(const string& input_name, bool* is_control_input) { + *is_control_input = false; + if (RE2::FullMatch(input_name, *valid_data_input_pattern)) { + return Status::OK(); + } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) { + *is_control_input = true; + return Status::OK(); + } else { + return errors::InvalidArgument("Illegal op input name '", input_name, "'"); + } +} + +Status ValidateOpName(const string& op_name) { + if (RE2::FullMatch(op_name, *valid_op_name_pattern)) { + return Status::OK(); + } else { + return errors::InvalidArgument("Illegal op name '", op_name, "'"); + } +} + +Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { + Status s = ValidateOpName(node_def.name()); + if (!s.ok()) { + return AttachDef(s, node_def); + } + bool in_control_inputs = false; + for (const string& input_name : node_def.input()) { + bool is_control_input; + s = ValidateOpInput(input_name, &is_control_input); + if (!s.ok()) { + return AttachDef(s, node_def); + } + + if (in_control_inputs && !is_control_input) { + return AttachDef(errors::InvalidArgument( + "All control inputs must follow all data inputs"), + node_def); + } + in_control_inputs = is_control_input; + } + return Status::OK(); +} + +Status AttachDef(const Status& status, const NodeDef& node_def) { + Status ret = status; + errors::AppendToMessage( + &ret, strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]")); + return ret; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h new file mode 100644 index 0000000000..fce6fd2433 --- /dev/null +++ b/tensorflow/core/framework/node_def_util.h @@ -0,0 +1,157 @@ +#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_ + +#include <string> +#include <vector> + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Produce a human-readable version of a NodeDef that is more concise +// than a text-format proto. +string SummarizeNodeDef(const NodeDef& node_def); + +typedef protobuf::Map<string, AttrValue> AttrValueMap; + +// Adds an attr with name and value to *node_def. +// The type of the attr is based on the type of value. +template <class T> +void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) { + AttrValue attr_value; + SetAttrValue(std::forward<T>(value), &attr_value); + node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value)); +} + +// Version to workaround C++'s "perfect" forwarding not being able to +// forward {...} initialization. +template <class T> +void AddNodeAttr(const string& name, std::initializer_list<T> value, + NodeDef* node_def) { + AttrValue attr_value; + SetAttrValue(value, &attr_value); + node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value)); +} + +class AttrSlice { + public: + AttrSlice(const NodeDef& node_def) // NOLINT(runtime/explicit) + : ndef_(&node_def), + attrs_(&ndef_->attr()) {} + + explicit AttrSlice(const AttrValueMap* a) : attrs_(a) {} + + // Returns the attr with attr_name if found. Otherwise, returns + // nullptr. + const AttrValue* Find(const string& attr_name) const; + + // Returns the attr_value for attr_name if found. Otherwise, returns a + // NotFound status.
+ Status Find(const string& attr_name, const AttrValue** attr_value) const; + + private: + const NodeDef* ndef_ = nullptr; + const AttrValueMap* attrs_; +}; + +// Look up the attr with name attr_name and set *value to its value. If no +// attr with attr_name is found in node_def, or the attr does not have +// a matching type, a non-ok status will be returned. +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + string* value); // type: "string" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + int64* value); // type: "int" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + int32* value); // type: "int" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + float* value); // type: "float" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + bool* value); // type: "bool" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + DataType* value); // type: "type" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + TensorShapeProto* value); // type: "shape" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + TensorShape* value); // type: "shape" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + Tensor* value); // type: "tensor" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<string>* value); // type "list(string)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<int64>* value); // type "list(int)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<int32>* value); // type "list(int)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<float>* value); // type "list(float)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<bool>* value); // type "list(bool)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<DataType>* value); // type "list(type)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + DataTypeVector* value); // type "list(type)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<TensorShapeProto>* value); // type "list(shape)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<TensorShape>* value); // type "list(shape)" +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + std::vector<Tensor>* value); // type: "list(tensor)" + +// This version avoids copying the TensorProto. +// REQUIRES: Must not use *value beyond the lifetime of node_def. +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + const TensorProto** value); // type: "tensor" + +// This version avoids copying the NameAttrList. +// REQUIRES: Must not use *value beyond the lifetime of node_def. +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + const NameAttrList** value); // type: "func" + +// Computes the input and output types for a specific node, for +// attr-style ops. +// REQUIRES: ValidateOpDef(op_def).ok() +Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs, DataTypeVector* outputs); + +// Validates that the NodeDef: +// * Defines all expected attrs from the OpDef. +// * All attrs satisfy constraints from the OpDef. +// * Has a signature matching SignatureForNode(). +// etc. +Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); + +// Computes the mapping from input/output argument name to the +// corresponding input/output index range.
For example, + // input "foo" corresponds to input indices + // [ (*inputs)["foo"].first, (*inputs)["foo"].second ). +typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap; +Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); + +// Adds default values to *node_def for unspecified attrs from op_def. +void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def); + +// Validates the syntax of a NodeDef provided externally. +// +// The following is an EBNF-style syntax for NodeDef objects. Note that +// Node objects are actually specified as tensorflow::NodeDef protocol buffers, +// which contain many other fields that are not (currently) validated. +// +// Node = NodeName, Inputs +// Inputs = ( DataInput * ), ( ControlInput * ) +// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? +// ControlInput = "^", NodeName +// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * +Status ValidateExternalNodeDefSyntax(const NodeDef& node_def); + +// Returns "status" with kernel's NodeDef attached as additional text +// in the error message. +Status AttachDef(const Status& status, const NodeDef& node_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_ diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc new file mode 100644 index 0000000000..71f1760a09 --- /dev/null +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -0,0 +1,442 @@ +#include "tensorflow/core/framework/node_def_util.h" + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +OpDef ToOpDef(const OpDefBuilder& builder) { + OpDef op_def; + EXPECT_OK(builder.Finalize(&op_def)); + return op_def; +} + +NodeDef ToNodeDef(const string& text) { + NodeDef node_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); + return node_def; +} + +NodeDef ToNodeDef(const NodeDefBuilder& builder) { + NodeDef node_def; + EXPECT_OK(builder.Finalize(&node_def)); + return node_def; +} + +void ExpectSuccess(const NodeDef& good, const OpDef& op_def) { + EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def)) + << "NodeDef: " << SummarizeNodeDef(good) + << "; OpDef: " << SummarizeOpDef(op_def); +} + +void ExpectFailure(const NodeDef& bad, const OpDef& op_def, + const string& message) { + Status status = ValidateNodeDef(bad, op_def); + + EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad) + << "; OpDef: " << SummarizeOpDef(op_def); + if (status.ok()) return; + + EXPECT_TRUE(errors::IsInvalidArgument(status)) + << status << "; NodeDef: " << SummarizeNodeDef(bad) + << "; OpDef: " << SummarizeOpDef(op_def); + + LOG(INFO) << "Message: " << status.error_message(); + EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + << "NodeDef: " << SummarizeNodeDef(bad) + << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status + << "\nDoes not contain: " << message; +} + +TEST(NodeDefUtilTest, In) { + const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type")); + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } } + )proto"); + ExpectSuccess(node_def, op);
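// The node_def_util.h header above declares AddNodeAttr() and GetNodeAttr();
// the following is a minimal round-trip sketch (illustrative only, not code
// from this commit; the node and attr names are arbitrary):
//
//   NodeDef example;
//   example.set_name("n");
//   example.set_op("In");
//   example.add_input("a");
//   AddNodeAttr("T", DT_FLOAT, &example);        // writes a "type" attr
//
//   DataType dtype;
//   Status s = GetNodeAttr(AttrSlice(example), "T", &dtype);
//   // On success, s.ok() is true and dtype == DT_FLOAT.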
+ + EXPECT_EQ("n = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def)); + + // Mismatching Op names. + NodeDef bad = node_def; + bad.set_op("Wrong"); + ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op= 2") + .Attr("T: {float,double}")); + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } } + )proto"); + ExpectSuccess(node_def, op); + + EXPECT_EQ("n = SameIn[N=2, T=DT_DOUBLE](a, b)", SummarizeNodeDef(node_def)); + + // Illegal type + NodeDef bad = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } } + )proto"); + ExpectFailure(bad, op, + "Value for attr 'T' of string is not in the list of allowed " + "values: float, double"); + + // Too few inputs + bad = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } } + )proto"); + ExpectFailure(bad, op, "Value for attr 'N' of 1 must be at least minimum 2"); +} + +TEST(NodeDefUtilTest, AnyIn) { + const OpDef op = + ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1")); + + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectSuccess(node_def, op); + + EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a, b)", + SummarizeNodeDef(node_def)); + + const NodeDef bad = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } } + )proto"); + ExpectFailure(bad, op, "Length for attr 'T' of 0 must be at least minimum 1"); + + // With proto3 semantics, an empty value {} is indistinguishable from a value + // with an empty list in it. So we simply expect to get a message complaining + // about empty list for value {}. 
+ const NodeDef bad2 = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } } + )proto"); + ExpectFailure(bad2, op, + "Length for attr 'T' of 0 must be at least minimum 1"); +} + +TEST(NodeDefUtilTest, Device) { + const OpDef op_def1 = ToOpDef(OpDefBuilder("None")); + const NodeDef node_def1 = + ToNodeDef(NodeDefBuilder("d", &op_def1).Device("/cpu:17")); + ExpectSuccess(node_def1, op_def1); + EXPECT_EQ("d = None[_device=\"/cpu:17\"]()", SummarizeNodeDef(node_def1)); + + const OpDef op_def2 = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int")); + const NodeDef node_def2 = + ToNodeDef(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5")); + ExpectSuccess(node_def2, op_def2); + EXPECT_EQ("d = WithAttr[v=7, _device=\"/cpu:5\"]()", + SummarizeNodeDef(node_def2)); +} + +void ExpectValidSyntax(const NodeDef& good) { + EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good)) + << "NodeDef: " << SummarizeNodeDef(good); +} + +void ExpectInvalidSyntax(const NodeDef& bad, const string& message) { + Status status = ValidateExternalNodeDefSyntax(bad); + + ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad); + + EXPECT_TRUE(errors::IsInvalidArgument(status)) + << status << "; NodeDef: " << SummarizeNodeDef(bad); + + EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", " + << message; +} + +TEST(NodeDefUtilTest, ValidSyntax) { + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def); + + const NodeDef node_def_explicit_inputs = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:0' input:'b:123' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def_explicit_inputs); + + EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)", + SummarizeNodeDef(node_def_explicit_inputs)); + + const NodeDef node_def_control_input = ToNodeDef(R"proto( + name:'n-' op:'AnyIn' input:'a' input:'^b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def_control_input); + + const NodeDef node_def_invalid_name = ToNodeDef(R"proto( + name:'n:0' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_invalid_name, "Illegal op name 'n:0'"); + + const NodeDef node_def_internal_name = ToNodeDef(R"proto( + name:'_n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'"); + + const NodeDef node_def_internal_input_name = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'_a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_internal_input_name, + "Illegal op input name '_a'"); + + const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'^b:0' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_invalid_control_input_name, + "Illegal op input name '^b:0'"); + + const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'^a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_after_control, + 
"All control inputs must follow all data inputs"); +} + +TEST(NameRangesForNodeTest, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + NameRangeMap inputs, outputs; + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs); + + EXPECT_EQ("simple = Simple[](a, b)", SummarizeNodeDef(node_def)); + + OpDef bad_op_def = op_def; + bad_op_def.mutable_input_arg(0)->clear_type(); + EXPECT_FALSE(NameRangesForNode(node_def, bad_op_def, &inputs, &outputs).ok()); +} + +TEST(NameRangesForNodeTest, Polymorphic) { + const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic") + .Input("a: T") + .Input("b: T") + .Output("c: T") + .Attr("T: type")); + NameRangeMap inputs, outputs; + const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("poly", &op_def) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32))); + EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs); + EXPECT_EQ("poly = Polymorphic[T=DT_INT32](a, b)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("poly", &op_def) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(DT_BOOL))); + EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs); + EXPECT_EQ("poly = Polymorphic[T=DT_BOOL](a, b)", SummarizeNodeDef(node_def2)); +} + +TEST(NameRangesForNodeTest, NRepeats) { + const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats") + .Input("a: N * int32") + .Input("b: N * T") + .Output("c: T") + .Output("d: N * string") + .Output("e: M * bool") + .Attr("N: int") + .Attr("M: int") + .Attr("T: type")); + NameRangeMap inputs, outputs; + const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("nr", &op_def) + .Input(FakeInput(4, DT_INT32)) + .Input(FakeInput(4, DT_FLOAT)) + .Attr("M", 3)); + EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}), + outputs); + EXPECT_EQ( + "nr = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, b:2, b:3)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("nr", &op_def) + .Input(FakeInput(2, DT_INT32)) + .Input(FakeInput(2, DT_DOUBLE)) + .Attr("M", 7)); + EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}), + outputs); + EXPECT_EQ("nr = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)", + SummarizeNodeDef(node_def2)); + + NodeDef bad_node_def = node_def2; + bad_node_def.clear_attr(); + EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok()); +} + +TEST(NameRangesForNodeTest, TypeList) { + const OpDef op_def = ToOpDef(OpDefBuilder("TypeList") + .Input("a: T1") + .Input("b: T2") + .Output("c: T2") + .Output("d: T3") + .Output("e: T1") + .Attr("T1: list(type)") + .Attr("T2: list(type)") + .Attr("T3: list(type)")); + 
NameRangeMap inputs, outputs; + const NodeDef node_def1 = + ToNodeDef(NodeDefBuilder("tl", &op_def) + .Input(FakeInput({DT_BOOL, DT_FLOAT})) + .Input(FakeInput(4, DT_FLOAT)) + .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING})); + EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}), + outputs); + EXPECT_EQ( + "tl = TypeList[T1=[DT_BOOL, DT_FLOAT]," + " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT]," + " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("tl", &op_def) + .Input(FakeInput(7, DT_INT32)) + .Input(FakeInput({DT_DOUBLE})) + .Attr("T3", {DT_DOUBLE, DT_STRING})); + EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}), + outputs); + EXPECT_EQ( + "tl = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32," + " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]" + "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)", + SummarizeNodeDef(node_def2)); + + NodeDef bad_node_def = node_def2; + bad_node_def.clear_attr(); + EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h new file mode 100644 index 0000000000..8413d18f33 --- /dev/null +++ b/tensorflow/core/framework/numeric_op.h @@ -0,0 +1,96 @@ +#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ +#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// One input and one output, both the same type. +template <class T> +class UnaryOp : public OpKernel { + public: + explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum<T>::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt})); + } +}; + +// Two inputs and one output, all the same type. +template <class T> +class BinaryOp : public OpKernel { + public: + explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum<T>::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt})); + } +}; + +// For operations where the input and output are the same shape. +// +// For usage, see ../framework/elementwise_ops.cc. +template <class T, class CHILD> +class UnaryElementWiseOp : public UnaryOp<T> { + public: + using UnaryOp<T>::UnaryOp; + + void Compute(OpKernelContext* context) override { + // Output shape is the same as input shape. + const Tensor& input = context->input(0); + Tensor* output; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + static_cast<CHILD*>(this)->Operate(context, input, output); + } +}; + +// For binary elementwise operations.
+template <class T, class CHILD> +class BinaryElementWiseOp : public BinaryOp<T> { + public: + using BinaryOp<T>::BinaryOp; + + void Compute(OpKernelContext* context) override { + const Tensor& a = context->input(0); + const Tensor& b = context->input(1); + + if (!context->ValidateInputsAreSameShape(this)) { + return; + } + + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(), &output)); + + // Dispatch to the descendant's Operate() function. + switch (a.dims()) { +#define NDIM_CASE(NDIMS) \ + case NDIMS: { \ + static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \ + break; \ + } + + NDIM_CASE(1); + NDIM_CASE(2); + NDIM_CASE(3); + NDIM_CASE(4); + NDIM_CASE(5); + NDIM_CASE(6); + NDIM_CASE(7); + NDIM_CASE(8); +#undef NDIM_CASE + + default: + context->SetStatus(errors::OutOfRange( + "We only handle up to Tensor::dims() up to 8, not ", a.dims())); + break; + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h new file mode 100644 index 0000000000..366f00ae03 --- /dev/null +++ b/tensorflow/core/framework/numeric_types.h @@ -0,0 +1,15 @@ +#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ + +#include <complex> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Single precision complex. +typedef std::complex<float> complex64; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc new file mode 100644 index 0000000000..15b7eab4da --- /dev/null +++ b/tensorflow/core/framework/op.cc @@ -0,0 +1,135 @@ +#include "tensorflow/core/framework/op.h" + +#include <algorithm> +#include <memory> +#include <vector> +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// OpRegistry ----------------------------------------------------------------- + +OpRegistryInterface::~OpRegistryInterface() {} + +OpRegistry::OpRegistry() : initialized_(false) {} + +void OpRegistry::Register(std::function<OpDef()> func) { + mutex_lock lock(mu_); + if (initialized_) { + OpDef def = func(); + TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: " + << SummarizeOpDef(def); + } else { + deferred_.push_back(func); + } +} + +const OpDef* OpRegistry::LookUp(const string& op_type_name, + Status* status) const { + const OpDef* op_def = nullptr; + bool first_call = false; + { // Scope for lock. + mutex_lock lock(mu_); + first_call = CallDeferred(); + op_def = gtl::FindWithDefault(registry_, op_type_name, nullptr); + // Note: Can't hold mu_ while calling Export() below.
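// Referring back to numeric_op.h above: UnaryElementWiseOp is a CRTP base, so
// a concrete kernel passes itself as CHILD and provides Operate(). A minimal
// sketch (illustrative only; "NegOp" is a hypothetical kernel, not one defined
// by this commit):
//
//   template <class T>
//   class NegOp : public UnaryElementWiseOp<T, NegOp<T>> {
//    public:
//     using UnaryElementWiseOp<T, NegOp<T>>::UnaryElementWiseOp;
//     void Operate(OpKernelContext* context, const Tensor& input,
//                  Tensor* output) {
//       output->flat<T>() = -input.flat<T>();   // elementwise negation
//     }
//   };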
+ } + if (first_call) { + TF_QCHECK_OK(ValidateKernelRegistrations(this)); + } + if (op_def == nullptr) { + status->Update( + errors::NotFound("Op type not registered '", op_type_name, "'")); + static bool first = true; + if (first) { + OpList op_list; + Export(true, &op_list); + LOG(INFO) << "All registered Ops:"; + for (const auto& op : op_list.op()) { + LOG(INFO) << SummarizeOpDef(op); + } + first = false; + } + } + return op_def; +} + +void OpRegistry::Export(bool include_internal, OpList* ops) const { + mutex_lock lock(mu_); + CallDeferred(); + + std::vector<std::pair<string, const OpDef*>> sorted(registry_.begin(), + registry_.end()); + std::sort(sorted.begin(), sorted.end()); + + auto out = ops->mutable_op(); + out->Clear(); + out->Reserve(sorted.size()); + + for (const auto& item : sorted) { + if (include_internal || !StringPiece(item.first).starts_with("_")) { + *out->Add() = *item.second; + } + } +} + +string OpRegistry::DebugString(bool include_internal) const { + OpList op_list; + Export(include_internal, &op_list); + string ret; + for (const auto& op : op_list.op()) { + strings::StrAppend(&ret, SummarizeOpDef(op), "\n"); + } + return ret; +} + +bool OpRegistry::CallDeferred() const { + if (initialized_) return false; + initialized_ = true; + for (const auto& fn : deferred_) { + OpDef def = fn(); + TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: " + << SummarizeOpDef(def); + } + deferred_.clear(); + return true; +} + +Status OpRegistry::RegisterAlreadyLocked(const OpDef& def) const { + TF_RETURN_IF_ERROR(ValidateOpDef(def)); + + std::unique_ptr<OpDef> copy(new OpDef(def)); + if (gtl::InsertIfNotPresent(&registry_, def.name(), copy.get())) { + copy.release(); // Ownership transferred to op_registry + return Status::OK(); + } else { + return errors::AlreadyExists("Op with name ", def.name()); + } +} + +// static +OpRegistry* OpRegistry::Global() { + static OpRegistry* global_op_registry = new OpRegistry; + return global_op_registry; +} + +namespace register_op { +OpDefBuilder& RegisterOp(StringPiece name) { + VLOG(1) << "RegisterOp: " << name; + OpDefBuilder* b = new OpDefBuilder(name); + OpRegistry::Global()->Register([b]() -> ::tensorflow::OpDef { + OpDef op_def; + TF_QCHECK_OK(b->Finalize(&op_def)); + delete b; + return op_def; + }); + return *b; +} +} // namespace register_op + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h new file mode 100644 index 0000000000..95ad32df35 --- /dev/null +++ b/tensorflow/core/framework/op.h @@ -0,0 +1,122 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_H_ +#define TENSORFLOW_FRAMEWORK_OP_H_ + +#include <functional> +#include <unordered_map> + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Users that want to look up an OpDef by type name should take an +// OpRegistryInterface. Functions accepting a +// (const) OpRegistryInterface* may call LookUp() from multiple threads. +class OpRegistryInterface { + public: + virtual ~OpRegistryInterface(); + + // Returns nullptr and sets *status if no OpDef is registered under that + // name, otherwise returns the registered OpDef. + // Caller must not delete the returned pointer.
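  // A lookup sketch (illustrative only; "MatMul" stands in for any registered
  // op name):
  //
  //   Status status;
  //   const OpDef* op_def = OpRegistry::Global()->LookUp("MatMul", &status);
  //   if (op_def == nullptr) {
  //     // status now holds the NotFound error.
  //   }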
+ virtual const OpDef* LookUp(const string& op_type_name, + Status* status) const = 0; +}; + +// The standard implementation of OpRegistryInterface, along with a +// global singleton used for registering OpDefs via the REGISTER +// macros below. Thread-safe. +// +// Example registration: +// OpRegistry::Global()->Register([]()->OpDef{ +// OpDef def; +// // Populate def here. +// return def; +// }); +class OpRegistry : public OpRegistryInterface { + public: + OpRegistry(); + ~OpRegistry() override {} + + // Calls func() and registers the returned OpDef. Since Register() + // is normally called during program initialization (before main()), + // we defer calling func() until the first call to LookUp() or + // Export() (if one of those has already been called, func() is + // called immediately). + void Register(std::function func); + + const OpDef* LookUp(const string& op_type_name, + Status* status) const override; + + // Fills *ops with all registered OpDefss (except those with names + // starting with '_' if include_internal == false). + void Export(bool include_internal, OpList* ops) const; + + // Returns ASCII-format OpList for all registered OpDefs (except + // those with names starting with '_' if include_internal == false). + string DebugString(bool include_internal) const; + + // A singleton available at startup. + static OpRegistry* Global(); + + private: + // Ensures that all the functions in deferred_ get called, their OpDef's + // registered, and returns with deferred_ empty. Returns true the first + // time it is called. + bool CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Add 'def' to the registry. On failure, or if there is already an + // OpDef with that name registered, returns a non-okay status. + Status RegisterAlreadyLocked(const OpDef& def) const + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutable mutex mu_; + // Functions in deferred_ may only be called with mu_ held. + mutable std::vector> deferred_ GUARDED_BY(mu_); + mutable std::unordered_map registry_ GUARDED_BY(mu_); + mutable bool initialized_ GUARDED_BY(mu_); +}; + +// Support for defining the OpDef (specifying the semantics of the Op and how +// it should be created) and registering it in the OpRegistry::Global() +// registry. Usage: +// +// REGISTER_OP("my_op_name") +// .Attr(":") +// .Attr(":=") +// .Input(":") +// .Input(":Ref()") +// .Output(":") +// .Doc(R"( +// <1-line summary> +// +// : +// : +// )"); +// +// Note: .Doc() should be last. +// For details, see the OpDefBuilder class in op_def_builder.h. + +namespace register_op { +// To call OpRegistry::Global()->Register(...), used by the +// REGISTER_OP macro below. +OpDefBuilder& RegisterOp(StringPiece name); +} // namespace register_op + +#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name) +#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name) +#define REGISTER_OP_UNIQ(ctr, name) \ + static ::tensorflow::OpDefBuilder& register_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::register_op::RegisterOp(name) + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_H_ diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto new file mode 100644 index 0000000000..4a2e90b1b9 --- /dev/null +++ b/tensorflow/core/framework/op_def.proto @@ -0,0 +1,142 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/attr_value.proto"; +import "tensorflow/core/framework/types.proto"; + +// Defines an operation. 
A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + // TODO(josh11b): bool is_optional? + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the ..._list, fields of AttrValue). + // If type == "type" or "list(type)" above, then the type_list field + // of allowed_values has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the s_list field has + // the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. 
+ string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + // TODO(josh11b): Implement that optimization. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. 
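  // An illustrative OpDef in proto text format (a sketch, not an op shipped in
  // this commit), tying the fields above together: a polymorphic op whose two
  // inputs and one output share a type constrained by the attr "T":
  //
  //   name: "ExampleAdd"
  //   input_arg { name: "x" type_attr: "T" }
  //   input_arg { name: "y" type_attr: "T" }
  //   output_arg { name: "sum" type_attr: "T" }
  //   attr {
  //     name: "T"
  //     type: "type"
  //     allowed_values { list { type: [DT_FLOAT, DT_INT32] } }
  //   }
  //   summary: "Adds x and y elementwise."
  //   is_commutative: true
  //   is_aggregate: true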
+}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc new file mode 100644 index 0000000000..7d7c07de4c --- /dev/null +++ b/tensorflow/core/framework/op_def_builder.cc @@ -0,0 +1,447 @@ +#include "tensorflow/core/framework/op_def_builder.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +namespace { + +bool RE2Consume(StringPiece* sp, const char* pattern) { + RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); + bool r = RE2::Consume(&base_sp, pattern); + *sp = FromRegexpStringPiece(base_sp); + return r; +} + +bool RE2Consume(StringPiece* sp, const char* pattern, StringPiece* out) { + RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); + RegexpStringPiece base_out; + bool r = RE2::Consume(&base_sp, pattern, &base_out); + *sp = FromRegexpStringPiece(base_sp); + *out = FromRegexpStringPiece(base_out); + return r; +} + +bool RE2Consume(StringPiece* sp, const char* pattern, int64* out) { + RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); + bool r = RE2::Consume(&base_sp, pattern, out); + *sp = FromRegexpStringPiece(base_sp); + return r; +} + +string AttrError(StringPiece orig, const string& op_name) { + return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name); +} + +#define VERIFY(expr, ...) \ + do { \ + if (!(expr)) { \ + errors->push_back( \ + strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \ + return; \ + } \ + } while (false) + +void FinalizeAttr(StringPiece spec, OpDef* op_def, + std::vector* errors) { + OpDef::AttrDef* attr = op_def->add_attr(); + StringPiece orig(spec); + + // Parse ":" at the beginning. + StringPiece tmp_name; + VERIFY(RE2Consume(&spec, "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*", &tmp_name), + "Trouble parsing ':'"); + attr->set_name(tmp_name.data(), tmp_name.size()); + + // Read "" or "list()". + bool is_list = RE2Consume(&spec, "list\\s*\\(\\s*"); + string type; + if (spec.Consume("string")) { + type = "string"; + } else if (spec.Consume("int")) { + type = "int"; + } else if (spec.Consume("float")) { + type = "float"; + } else if (spec.Consume("bool")) { + type = "bool"; + } else if (spec.Consume("type")) { + type = "type"; + } else if (spec.Consume("shape")) { + type = "shape"; + } else if (spec.Consume("tensor")) { + type = "tensor"; + } else if (spec.Consume("func")) { + type = "func"; + } else if (spec.Consume("numbertype") || spec.Consume("numerictype")) { + type = "type"; + AttrValue* allowed = attr->mutable_allowed_values(); + for (DataType dt : NumberTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else if (spec.Consume("quantizedtype")) { + type = "type"; + AttrValue* allowed = attr->mutable_allowed_values(); + for (DataType dt : QuantizedTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else if (spec.Consume("realnumbertype") || + spec.Consume("realnumerictype")) { + type = "type"; + AttrValue* allowed = attr->mutable_allowed_values(); + for (DataType dt : RealNumberTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else if (spec.Consume("{")) { + // e.g. 
"{ int32, float, bool }" or "{ \"foo\", \"bar\" }" + RE2Consume(&spec, "\\s*"); + AttrValue* allowed = attr->mutable_allowed_values(); + if (spec.starts_with("\"") || spec.starts_with("'")) { + type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }" + while (true) { + StringPiece escaped_string; + VERIFY((RE2Consume(&spec, R"xx("((?:[^"\\]|\\.)*)"\s*)xx", + &escaped_string) || + RE2Consume(&spec, R"xx('((?:[^'\\]|\\.)*)'\s*)xx", + &escaped_string)), + "Trouble parsing allowed string at '", spec, "'"); + string unescaped; + string error; + VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error), + "Trouble unescaping \"", escaped_string, "\", got error: ", + error); + allowed->mutable_list()->add_s(unescaped); + if (spec.Consume(",")) { + RE2Consume(&spec, "\\s*"); + if (spec.Consume("}")) break; // Allow ending with ", }". + } else { + VERIFY(spec.Consume("}"), + "Expected , or } after strings in list, not: '", spec, "'"); + break; + } + } + } else { // "{ int32, float, bool }" + type = "type"; + while (true) { + StringPiece type_string; + VERIFY(RE2Consume(&spec, "([a-z0-9]+)\\s*", &type_string), + "Trouble parsing type string at '", spec, "'"); + DataType dt; + VERIFY(DataTypeFromString(type_string, &dt), + "Unrecognized type string '", type_string, "'"); + allowed->mutable_list()->add_type(dt); + if (spec.Consume(",")) { + RE2Consume(&spec, "\\s*"); + if (spec.Consume("}")) break; // Allow ending with ", }". + } else { + VERIFY(spec.Consume("}"), + "Expected , or } after types in list, not: '", spec, "'"); + break; + } + } + } + } else { + VERIFY(false, "Trouble parsing type string at '", spec, "'"); + } + RE2Consume(&spec, "\\s*"); + + // Write the type into *attr. + if (is_list) { + VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'"); + RE2Consume(&spec, "\\s*"); + attr->set_type(strings::StrCat("list(", type, ")")); + } else { + attr->set_type(type); + } + + // Read optional minimum constraint at the end. + if ((is_list || type == "int") && spec.Consume(">=")) { + int64 min_limit = -999; + VERIFY(RE2Consume(&spec, "\\s*(-?\\d+)\\s*", &min_limit), + "Could not parse integer lower limit after '>=', found '", spec, + "' instead"); + attr->set_has_minimum(true); + attr->set_minimum(min_limit); + } + + // Parse default value, if present. + if (spec.Consume("=")) { + RE2Consume(&spec, "\\s*"); + VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()), + "Could not parse default value '", spec, "'"); + } else { + VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end"); + } +} + +#undef VERIFY + +string InOutError(bool is_output, StringPiece orig, const string& op_name) { + return strings::StrCat(" from ", is_output ? "Output" : "Input", "(\"", orig, + "\") for Op ", op_name); +} + +#define VERIFY(expr, ...) \ + do { \ + if (!(expr)) { \ + errors->push_back(strings::StrCat( \ + __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \ + return; \ + } \ + } while (false) + +void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, + std::vector* errors) { + OpDef::ArgDef* arg = + is_output ? op_def->add_output_arg() : op_def->add_input_arg(); + + StringPiece orig(spec); + + // Parse ":" at the beginning. + StringPiece tmp_name; + VERIFY(RE2Consume(&spec, "([a-z][a-z0-9_]*)\\s*:\\s*", &tmp_name), + "Trouble parsing 'name:'"); + arg->set_name(tmp_name.data(), tmp_name.size()); + + // Detect "Ref(...)". + if (RE2Consume(&spec, "Ref\\s*\\(\\s*")) { + arg->set_is_ref(true); + } + + { // Parse "" or "*". 
+ StringPiece first, second, type_or_attr; + VERIFY(RE2Consume(&spec, "([a-zA-Z][a-zA-Z0-9_]*)\\s*", &first), + "Trouble parsing either a type or an attr name at '", spec, "'"); + if (RE2Consume(&spec, "[*]\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*", &second)) { + arg->set_number_attr(first.data(), first.size()); + type_or_attr = second; + } else { + type_or_attr = first; + } + DataType dt; + if (DataTypeFromString(type_or_attr, &dt)) { + arg->set_type(dt); + } else { + const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def); + VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'"); + if (attr->type() == "type") { + arg->set_type_attr(type_or_attr.data(), type_or_attr.size()); + } else { + VERIFY(attr->type() == "list(type)", "Reference to attr '", + type_or_attr, "' with type ", attr->type(), + " that isn't type or list(type)"); + arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size()); + } + } + } + + // Closing ) for Ref(. + if (arg->is_ref()) { + VERIFY(RE2Consume(&spec, "\\)\\s*"), + "Did not find closing ')' for 'Ref(', instead found: '", spec, "'"); + } + + // Should not have anything else. + VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end"); + + // Int attrs that are the length of an input or output get a default + // minimum of 1. + if (!arg->number_attr().empty()) { + OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def); + if (attr != nullptr && !attr->has_minimum()) { + attr->set_has_minimum(true); + attr->set_minimum(1); + } + } else if (!arg->type_list_attr().empty()) { + // If an input or output has type specified by a list(type) attr, + // it gets a default minimum of 1 as well. + OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def); + if (attr != nullptr && attr->type() == "list(type)" && + !attr->has_minimum()) { + attr->set_has_minimum(true); + attr->set_minimum(1); + } + } +} + +#undef VERIFY + +int num_leading_spaces(StringPiece s) { + size_t i = 0; + while (i < s.size() && s[i] == ' ') { + ++i; + } + return i; +} + +void FinalizeDoc(const string& text, OpDef* op_def, + std::vector* errors) { + std::vector lines = str_util::Split(text, '\n'); + + // Remove trailing spaces. + for (string& line : lines) { + str_util::StripTrailingWhitespace(&line); + } + + // First non-blank line -> summary. + int l = 0; + while (static_cast(l) < lines.size() && lines[l].empty()) ++l; + if (static_cast(l) < lines.size()) { + op_def->set_summary(lines[l]); + ++l; + } + while (static_cast(l) < lines.size() && lines[l].empty()) ++l; + + // Lines until we see name: -> description. + int start_l = l; + while (static_cast(l) < lines.size() && + !RE2::PartialMatch(lines[l], "^[a-zA-Z][a-zA-Z0-9_]*\\s*:")) { + ++l; + } + int end_l = l; + // Trim trailing blank lines from the description. + while (start_l < end_l && lines[end_l - 1].empty()) --end_l; + string desc = str_util::Join( + gtl::ArraySlice(lines.data() + start_l, end_l - start_l), "\n"); + if (!desc.empty()) op_def->set_description(desc); + + // name: description + // possibly continued on the next line + // if so, we remove the minimum indent + StringPiece name; + std::vector description; + while (static_cast(l) < lines.size()) { + description.clear(); + description.push_back(lines[l]); + RE2Consume(&description.back(), "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*", &name); + ++l; + while (static_cast(l) < lines.size() && + !RE2::PartialMatch(lines[l], "^[a-zA-Z][a-zA-Z0-9_]*\\s*:")) { + description.push_back(lines[l]); + ++l; + } + // Remove any trailing blank lines. 
+ while (!description.empty() && description.back().empty()) { + description.pop_back(); + } + // Compute the minimum indent of all lines after the first. + int min_indent = -1; + for (size_t i = 1; i < description.size(); ++i) { + if (!description[i].empty()) { + int indent = num_leading_spaces(description[i]); + if (min_indent < 0 || indent < min_indent) min_indent = indent; + } + } + // Remove min_indent spaces from all lines after the first. + for (size_t i = 1; i < description.size(); ++i) { + if (!description[i].empty()) description[i].remove_prefix(min_indent); + } + // Concatenate lines into a single string. + const string complete(str_util::Join(description, "\n")); + + // Find name. + bool found = false; + for (int i = 0; !found && i < op_def->input_arg_size(); ++i) { + if (op_def->input_arg(i).name() == name) { + op_def->mutable_input_arg(i)->set_description(complete); + found = true; + } + } + for (int i = 0; !found && i < op_def->output_arg_size(); ++i) { + if (op_def->output_arg(i).name() == name) { + op_def->mutable_output_arg(i)->set_description(complete); + found = true; + } + } + for (int i = 0; !found && i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == name) { + op_def->mutable_attr(i)->set_description(complete); + found = true; + } + } + if (!found) { + errors->push_back( + strings::StrCat("No matching input/output/attr for name '", name, + "' from Doc() for Op ", op_def->name())); + return; + } + } +} + +} // namespace + +OpDefBuilder::OpDefBuilder(StringPiece op_name) { + op_def_.set_name(op_name.ToString()); // NOLINT +} + +OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) { + attrs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +OpDefBuilder& OpDefBuilder::Input(StringPiece spec) { + inputs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +OpDefBuilder& OpDefBuilder::Output(StringPiece spec) { + outputs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +OpDefBuilder& OpDefBuilder::Doc(StringPiece text) { + if (!doc_.empty()) { + errors_.push_back( + strings::StrCat("Extra call to Doc() for Op ", op_def_.name())); + } else { + doc_.assign(text.data(), text.size()); + } + return *this; +} + +OpDefBuilder& OpDefBuilder::SetIsCommutative() { + op_def_.set_is_commutative(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetIsAggregate() { + op_def_.set_is_aggregate(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetIsStateful() { + op_def_.set_is_stateful(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() { + op_def_.set_allows_uninitialized_input(true); + return *this; +} + +Status OpDefBuilder::Finalize(OpDef* op_def) const { + std::vector errors = errors_; + *op_def = op_def_; + + for (StringPiece attr : attrs_) { + FinalizeAttr(attr, op_def, &errors); + } + for (StringPiece input : inputs_) { + FinalizeInputOrOutput(input, false, op_def, &errors); + } + for (StringPiece output : outputs_) { + FinalizeInputOrOutput(output, true, op_def, &errors); + } + FinalizeDoc(doc_, op_def, &errors); + + if (errors.empty()) return Status::OK(); + return errors::InvalidArgument(str_util::Join(errors, "\n")); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h new file mode 100644 index 0000000000..017338c508 --- /dev/null +++ b/tensorflow/core/framework/op_def_builder.h @@ -0,0 +1,109 @@ +// Class and associated machinery for specifying an Op's OpDef for Op +// registration. 
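// To make the parsing code in op_def_builder.cc above concrete, an illustrative
// registration (a sketch only; "PadExample" is a hypothetical op, not one
// defined by this commit). FinalizeAttr() parses the Attr() specs,
// FinalizeInputOrOutput() parses the Input()/Output() specs, and FinalizeDoc()
// parses the Doc() text into a summary and per-name descriptions:
//
//   REGISTER_OP("PadExample")
//       .Input("input: T")
//       .Input("paddings: int32")
//       .Output("output: T")
//       .Attr("T: {float, int32}")
//       .Attr("align: int >= 1 = 1")
//       .Doc(R"(
//   Pads a tensor (illustrative one-line summary).
//
//   input: The tensor to pad.
//   align: Alignment to pad to; defaults to 1.
//   )");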
+
+#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Builder class passed to the REGISTER_OP() macro.
+class OpDefBuilder {
+ public:
+  // Constructs an OpDef with just the name field set.
+  explicit OpDefBuilder(StringPiece op_name);
+
+  // Adds an attr to this OpDefBuilder (and returns *this). The spec has
+  // format "<name>:<type>" or "<name>:<type>=<default>"
+  // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*
+  // (by convention only using capital letters for attrs that can be inferred)
+  // <type> can be:
+  //   "string", "int", "float", "bool", "type", "shape", or "tensor"
+  //   "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
+  //       (meaning "type" with a restriction on valid values)
+  //   "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
+  //       (meaning "string" with a restriction on valid values)
+  //   "list(string)", ..., "list(tensor)", "list(numbertype)", ...
+  //       (meaning lists of the above types)
+  //   "int >= 2" (meaning "int" with a restriction on valid values)
+  //   "list(string) >= 2", "list(int) >= 2"
+  //       (meaning "list(string)" / "list(int)" with length at least 2)
+  // <default>, if included, should use the Proto text format
+  // of <type>. For lists use [a, b, c] format.
+  //
+  // Note that any attr specifying the length of an input or output will
+  // get a default minimum of 1 unless the >= # syntax is used.
+  //
+  // TODO(josh11b): Perhaps support restrictions and defaults as optional
+  // extra arguments to Attr() instead of encoding them in the spec string.
+  // TODO(josh11b): Would like to have better dtype handling for tensor attrs:
+  // * Ability to say the type of an input/output matches the type of
+  //   the tensor.
+  // * Ability to restrict the type of the tensor like the existing
+  //   restrictions for type attrs.
+  // Perhaps by linking the type of the tensor to a type attr?
+  OpDefBuilder& Attr(StringPiece spec);
+
+  // Adds an input or output to this OpDefBuilder (and returns *this).
+  // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
+  // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
+  // * For a single tensor: <type>
+  // * For a sequence of tensors with the same type: <number> * <type>
+  // * For a sequence of tensors with different types: <type-list>
+  // Where:
+  //   <type> is either one of "float", "int32", "string", ...
+  //          or the name of an attr (see above) with type "type".
+  //   <number> is the name of an attr with type "int".
+  //   <type-list> is the name of an attr with type "list(type)".
+  // TODO(josh11b): Indicate Ref() via an optional argument instead of
+  //   in the spec?
+  // TODO(josh11b): SparseInput() and SparseOutput() matching the Python
+  // handling?
+  OpDefBuilder& Input(StringPiece spec);
+  OpDefBuilder& Output(StringPiece spec);
+
+  // Turns on the indicated boolean flag in this OpDefBuilder (and
+  // returns *this).
+  OpDefBuilder& SetIsCommutative();
+  OpDefBuilder& SetIsAggregate();
+  OpDefBuilder& SetIsStateful();
+  OpDefBuilder& SetAllowsUninitializedInput();
+
+  // Adds docs to this OpDefBuilder (and returns *this).
+  // Docs have the format:
+  //   <1-line summary>
+  //   <rest of the description>
+  //   <name>: <description>
+  //   <name>: <description>
+  //   <...>
+  // Where <name> is the name of an attr, input, or output. Please
+  // wrap docs at 72 columns so that it may be indented in the
For tensor inputs or outputs (not attrs), you + // may start the description with an "=" (like name:= ) + // to suppress the automatically-generated type documentation in + // generated output. + OpDefBuilder& Doc(StringPiece text); + + // Sets *op_def to the requested OpDef, or returns an error. + // Must be called after all of the above methods. + // Note that OpDefBuilder only reports parsing errors. You should also + // call ValidateOpDef() to detect other problems. + Status Finalize(OpDef* op_def) const; + + private: + OpDef op_def_; + std::vector attrs_; + std::vector inputs_; + std::vector outputs_; + string doc_; + std::vector errors_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc new file mode 100644 index 0000000000..e53bad7075 --- /dev/null +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -0,0 +1,519 @@ +#include "tensorflow/core/framework/op_def_builder.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include + +namespace tensorflow { +namespace { + +static void CanonicalizeAttrTypeListOrder(OpDef* def) { + for (int i = 0; i < def->attr_size(); i++) { + AttrValue* a = def->mutable_attr(i)->mutable_allowed_values(); + std::sort(a->mutable_list()->mutable_type()->begin(), + a->mutable_list()->mutable_type()->end()); + } +} + +class OpDefBuilderTest : public ::testing::Test { + protected: + OpDefBuilder b() { return OpDefBuilder("Test"); } + + void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto) { + OpDef op_def; + Status status = builder.Finalize(&op_def); + EXPECT_OK(status); + if (status.ok()) { + OpDef expected; + protobuf::TextFormat::ParseFromString( + strings::StrCat("name: 'Test' ", proto), &expected); + // Allow different orderings + CanonicalizeAttrTypeListOrder(&op_def); + CanonicalizeAttrTypeListOrder(&expected); + EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString()); + } + } + + void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) { + OpDef op_def; + Status status = builder.Finalize(&op_def); + EXPECT_OK(status); + if (status.ok()) { + OpDef expected; + protobuf::TextFormat::ParseFromString( + strings::StrCat("name: 'Test' ", proto), &expected); + EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString()); + } + } + + void ExpectFailure(const OpDefBuilder& builder, string error) { + OpDef op_def; + Status status = builder.Finalize(&op_def); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + EXPECT_EQ(status.error_message(), error); + } + } +}; + +TEST_F(OpDefBuilderTest, Attr) { + ExpectSuccess(b().Attr("a:string"), "attr: { name: 'a' type: 'string' }"); + ExpectSuccess(b().Attr("A: int"), "attr: { name: 'A' type: 'int' }"); + ExpectSuccess(b().Attr("a1 :float"), "attr: { name: 'a1' type: 'float' }"); + ExpectSuccess(b().Attr("a_a : bool"), "attr: { name: 'a_a' type: 'bool' }"); + ExpectSuccess(b().Attr("aB : type"), "attr: { name: 'aB' type: 'type' }"); + ExpectSuccess(b().Attr("aB_3\t: shape"), + "attr: { name: 'aB_3' type: 'shape' }"); + ExpectSuccess(b().Attr("t: tensor"), "attr: { name: 't' type: 'tensor' }"); + ExpectSuccess(b().Attr("XYZ\t:\tlist(type)"), + "attr: { name: 'XYZ' type: 'list(type)' 
}"); + ExpectSuccess(b().Attr("f: func"), "attr { name: 'f' type: 'func'}"); +} + +TEST_F(OpDefBuilderTest, AttrFailure) { + ExpectFailure( + b().Attr("_:string"), + "Trouble parsing ':' from Attr(\"_:string\") for Op Test"); + ExpectFailure( + b().Attr("9:string"), + "Trouble parsing ':' from Attr(\"9:string\") for Op Test"); + ExpectFailure(b().Attr(":string"), + "Trouble parsing ':' from Attr(\":string\") for Op Test"); + ExpectFailure(b().Attr("string"), + "Trouble parsing ':' from Attr(\"string\") for Op Test"); + ExpectFailure(b().Attr("a:invalid"), + "Trouble parsing type string at 'invalid' from " + "Attr(\"a:invalid\") for Op Test"); + ExpectFailure( + b().Attr("b:"), + "Trouble parsing type string at '' from Attr(\"b:\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrWithRestrictions) { + // Types with restrictions. + ExpectSuccess(b().Attr("a:numbertype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32] } } }"); + ExpectSuccess(b().Attr("a:realnumbertype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_INT8] } } }"); + ExpectSuccess(b().Attr("a:quantizedtype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_QINT8, DT_QUINT8, DT_QINT32] } } }"); + ExpectSuccess(b().Attr("a:{string,int32}"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_STRING, DT_INT32] } } }"); + ExpectSuccess(b().Attr("a: { float , complex64 } "), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_COMPLEX64] } } }"); + ExpectSuccess(b().Attr("a: {float, complex64,} "), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_COMPLEX64] } }"); + ExpectSuccess(b().Attr(R"(a: { "X", "yz" })"), + "attr: { name: 'a' type: 'string' allowed_values { list { s: " + "['X', 'yz'] } } }"); + ExpectSuccess(b().Attr(R"(a: { "X", "yz", })"), + "attr: { name: 'a' type: 'string' allowed_values { list { s: " + "['X', 'yz'] } } }"); + ExpectSuccess( + b().Attr("i: int >= -5"), + "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }"); +} + +TEST_F(OpDefBuilderTest, AttrRestrictionFailure) { + ExpectFailure( + b().Attr("a:{}"), + "Trouble parsing type string at '}' from Attr(\"a:{}\") for Op Test"); + ExpectFailure( + b().Attr("a:{,}"), + "Trouble parsing type string at ',}' from Attr(\"a:{,}\") for Op Test"); + ExpectFailure(b().Attr("a:{invalid}"), + "Unrecognized type string 'invalid' from Attr(\"a:{invalid}\") " + "for Op Test"); + ExpectFailure(b().Attr("a:{\"str\", float}"), + "Trouble parsing allowed string at 'float}' from " + "Attr(\"a:{\"str\", float}\") for Op Test"); + ExpectFailure(b().Attr("a:{ float, \"str\" }"), + "Trouble parsing type string at '\"str\" }' from Attr(\"a:{ " + "float, \"str\" }\") for Op Test"); + ExpectFailure(b().Attr("a:{float,,string}"), + "Trouble parsing type string at ',string}' from " + "Attr(\"a:{float,,string}\") for Op Test"); + ExpectFailure(b().Attr("a:{float,,}"), + "Trouble parsing type string at ',}' from " + "Attr(\"a:{float,,}\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrListOfRestricted) { + ExpectSuccess( + b().Attr("a:list(realnumbertype)"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_INT8] } } }"); + ExpectSuccess( + 
b().Attr("a:list(quantizedtype)"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_QINT8, DT_QUINT8, DT_QINT32] } } }"); + ExpectSuccess( + b().Attr("a: list({float, string, bool})"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_FLOAT, DT_STRING, DT_BOOL] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ "one fish", "two fish" }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "['one fish', 'two fish'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ 'red fish', 'blue fish' }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "['red fish', 'blue fish'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ "single' ", 'double"' }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "[\"single' \", 'double\"'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ 'escape\'\n', "from\\\"NY" }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "[\"escape'\\n\", 'from\\\\\"NY'] } } }"); +} + +TEST_F(OpDefBuilderTest, AttrListWithMinLength) { + ExpectSuccess( + b().Attr("i: list(bool) >= 4"), + "attr: { name: 'i' type: 'list(bool)' has_minimum: true minimum: 4 }"); +} + +TEST_F(OpDefBuilderTest, AttrWithDefaults) { + ExpectSuccess(b().Attr(R"(a:string="foo")"), + "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"); + ExpectSuccess(b().Attr(R"(a:string='foo')"), + "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"); + ExpectSuccess(b().Attr("a:float = 1.25"), + "attr: { name: 'a' type: 'float' default_value { f: 1.25 } }"); + ExpectSuccess(b().Attr("a:tensor = { dtype: DT_INT32 int_val: 5 }"), + "attr: { name: 'a' type: 'tensor' default_value { tensor {" + " dtype: DT_INT32 int_val: 5 } } }"); + ExpectSuccess(b().Attr("a:shape = { dim { size: 3 } dim { size: 4 } }"), + "attr: { name: 'a' type: 'shape' default_value { shape {" + " dim { size: 3 } dim { size: 4 } } } }"); +} + +TEST_F(OpDefBuilderTest, AttrFailedDefaults) { + ExpectFailure(b().Attr(R"(a:int="foo")"), + "Could not parse default value '\"foo\"' from " + "Attr(\"a:int=\"foo\"\") for Op Test"); + ExpectFailure(b().Attr("a:float = [1.25]"), + "Could not parse default value '[1.25]' from Attr(\"a:float = " + "[1.25]\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrListWithDefaults) { + ExpectSuccess(b().Attr(R"(a:list(string)=["foo", "bar"])"), + "attr: { name: 'a' type: 'list(string)' " + "default_value { list { s: ['foo', 'bar'] } } }"); + ExpectSuccess(b().Attr("a:list(bool)=[true, false, true]"), + "attr: { name: 'a' type: 'list(bool)' " + "default_value { list { b: [true, false, true] } } }"); + ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"), + "attr: { name: 'a' type: 'list(int)' " + "default_value { list { i: [0, -1, 2, -4, 8] } } }"); +} + +TEST_F(OpDefBuilderTest, AttrFailedListDefaults) { + ExpectFailure(b().Attr(R"(a:list(int)=["foo"])"), + "Could not parse default value '[\"foo\"]' from " + "Attr(\"a:list(int)=[\"foo\"]\") for Op Test"); + ExpectFailure(b().Attr(R"(a:list(int)=[7, "foo"])"), + "Could not parse default value '[7, \"foo\"]' from " + "Attr(\"a:list(int)=[7, \"foo\"]\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = [[1.25]]"), + "Could not parse default value '[[1.25]]' from " + "Attr(\"a:list(float) = [[1.25]]\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = 1.25"), + "Could not parse default value '1.25' from " + "Attr(\"a:list(float) = 1.25\") for Op Test"); + 
ExpectFailure(b().Attr(R"(a:list(string)='foo')"), + "Could not parse default value ''foo'' from " + "Attr(\"a:list(string)='foo'\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, InputOutput) { + ExpectSuccess(b().Input("a: int32"), + "input_arg: { name: 'a' type: DT_INT32 }"); + ExpectSuccess(b().Output("b: string"), + "output_arg: { name: 'b' type: DT_STRING }"); + ExpectSuccess(b().Input("c: float "), + "input_arg: { name: 'c' type: DT_FLOAT }"); + ExpectSuccess(b().Output("d: Ref(bool)"), + "output_arg: { name: 'd' type: DT_BOOL is_ref: true }"); + ExpectOrdered(b().Input("a: bool") + .Output("c: complex64") + .Input("b: int64") + .Output("d: string"), + "input_arg: { name: 'a' type: DT_BOOL } " + "input_arg: { name: 'b' type: DT_INT64 } " + "output_arg: { name: 'c' type: DT_COMPLEX64 } " + "output_arg: { name: 'd' type: DT_STRING }"); +} + +TEST_F(OpDefBuilderTest, PolymorphicInputOutput) { + ExpectSuccess(b().Input("a: foo").Attr("foo: type"), + "input_arg: { name: 'a' type_attr: 'foo' } " + "attr: { name: 'foo' type: 'type' }"); + ExpectSuccess(b().Output("a: foo").Attr("foo: { bool, int32 }"), + "output_arg: { name: 'a' type_attr: 'foo' } " + "attr: { name: 'foo' type: 'type' " + "allowed_values: { list { type: [DT_BOOL, DT_INT32] } } }"); +} + +TEST_F(OpDefBuilderTest, InputOutputListSameType) { + ExpectSuccess(b().Input("a: n * int32").Attr("n: int"), + "input_arg: { name: 'a' number_attr: 'n' type: DT_INT32 } " + "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 }"); + // Polymorphic case: + ExpectSuccess(b().Output("b: n * foo").Attr("n: int").Attr("foo: type"), + "output_arg: { name: 'b' number_attr: 'n' type_attr: 'foo' } " + "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 } " + "attr: { name: 'foo' type: 'type' }"); +} + +TEST_F(OpDefBuilderTest, InputOutputListAnyType) { + ExpectSuccess( + b().Input("c: foo").Attr("foo: list(type)"), + "input_arg: { name: 'c' type_list_attr: 'foo' } " + "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 }"); + ExpectSuccess( + b().Output("c: foo").Attr("foo: list({string, float})"), + "output_arg: { name: 'c' type_list_attr: 'foo' } " + "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 " + "allowed_values: { list { type: [DT_STRING, DT_FLOAT] } } }"); +} + +TEST_F(OpDefBuilderTest, InputOutputFailure) { + ExpectFailure(b().Input("9: int32"), + "Trouble parsing 'name:' from Input(\"9: int32\") for Op Test"); + ExpectFailure( + b().Output("_: int32"), + "Trouble parsing 'name:' from Output(\"_: int32\") for Op Test"); + ExpectFailure(b().Input(": int32"), + "Trouble parsing 'name:' from Input(\": int32\") for Op Test"); + ExpectFailure(b().Output("int32"), + "Trouble parsing 'name:' from Output(\"int32\") for Op Test"); + ExpectFailure( + b().Input("CAPS: int32"), + "Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test"); + ExpectFailure(b().Input("a: _"), + "Trouble parsing either a type or an attr name at '_' from " + "Input(\"a: _\") for Op Test"); + ExpectFailure(b().Input("a: 9"), + "Trouble parsing either a type or an attr name at '9' from " + "Input(\"a: 9\") for Op Test"); + ExpectFailure(b().Input("a: 9 * int32"), + "Trouble parsing either a type or an attr name at '9 * int32' " + "from Input(\"a: 9 * int32\") for Op Test"); + ExpectFailure( + b().Input("a: x * _").Attr("x: type"), + "Extra '* _' unparsed at the end from Input(\"a: x * _\") for Op Test"); + ExpectFailure(b().Input("a: x * y extra").Attr("x: int").Attr("y: type"), + "Extra 'extra' unparsed at the 
end from Input(\"a: x * y " + "extra\") for Op Test"); + ExpectFailure(b().Input("a: Ref(int32"), + "Did not find closing ')' for 'Ref(', instead found: '' from " + "Input(\"a: Ref(int32\") for Op Test"); + ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"), + "Did not find closing ')' for 'Ref(', instead found: 'y' from " + "Input(\"a: Ref(x y\") for Op Test"); + ExpectFailure( + b().Input("a: x"), + "Reference to unknown attr 'x' from Input(\"a: x\") for Op Test"); + ExpectFailure( + b().Input("a: x * y").Attr("x: int"), + "Reference to unknown attr 'y' from Input(\"a: x * y\") for Op Test"); + ExpectFailure(b().Input("a: x").Attr("x: int"), + "Reference to attr 'x' with type int that isn't type or " + "list(type) from Input(\"a: x\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, Set) { + ExpectSuccess(b().SetIsStateful(), "is_stateful: true"); + ExpectSuccess(b().SetIsCommutative().SetIsAggregate(), + "is_commutative: true is_aggregate: true"); +} + +TEST_F(OpDefBuilderTest, DocUnpackSparseFeatures) { + ExpectOrdered(b().Input("sf: string") + .Output("indices: int32") + .Output("ids: int64") + .Output("weights: float") + .Doc(R"doc( +Converts a vector of strings with dist_belief::SparseFeatures to tensors. + +Note that indices, ids and weights are vectors of the same size and have +one-to-one correspondence between their elements. ids and weights are each +obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in +1...size(sf). Note that if sf[i].weight is not set, the default value for the +weight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were +extracted from sf[i], then index[j] is set to i. + +sf: vector of string, where each element is the string encoding of + SparseFeatures proto. +indices: vector of indices inside sf +ids: vector of id extracted from the SparseFeatures proto. +weights: vector of weight extracted from the SparseFeatures proto. +)doc"), + R"proto( +input_arg { + name: "sf" + description: "vector of string, where each element is the string encoding of\nSparseFeatures proto." + type: DT_STRING +} +output_arg { + name: "indices" + description: "vector of indices inside sf" + type: DT_INT32 +} +output_arg { + name: "ids" + description: "vector of id extracted from the SparseFeatures proto." + type: DT_INT64 +} +output_arg { + name: "weights" + description: "vector of weight extracted from the SparseFeatures proto." + type: DT_FLOAT +} +summary: "Converts a vector of strings with dist_belief::SparseFeatures to tensors." +description: "Note that indices, ids and weights are vectors of the same size and have\none-to-one correspondence between their elements. ids and weights are each\nobtained by sequentially concatenating sf[i].id and sf[i].weight, for i in\n1...size(sf). Note that if sf[i].weight is not set, the default value for the\nweight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were\nextracted from sf[i], then index[j] is set to i." +)proto"); +} + +TEST_F(OpDefBuilderTest, DocConcat) { + ExpectOrdered(b().Input("concat_dim: int32") + .Input("values: num_values * dtype") + .Output("output: dtype") + .Attr("dtype: type") + .Attr("num_values: int >= 2") + .Doc(R"doc( +Concatenate N Tensors along one dimension. + +concat_dim: The (scalar) dimension along which to concatenate. Must be + in the range [0, rank(values...)). +values: The N Tensors to concatenate. Their ranks and types must match, + and their sizes must match in all dimensions except concat_dim. 
+output: A Tensor with the concatenation of values stacked along the + concat_dim dimension. This Tensor's shape matches the Tensors in + values, except in concat_dim where it has the sum of the sizes. +)doc"), + R"proto( +input_arg { + name: "concat_dim" + description: "The (scalar) dimension along which to concatenate. Must be\nin the range [0, rank(values...))." + type: DT_INT32 +} +input_arg { + name: "values" + description: "The N Tensors to concatenate. Their ranks and types must match,\nand their sizes must match in all dimensions except concat_dim." + type_attr: "dtype" + number_attr: "num_values" +} +output_arg { + name: "output" + description: "A Tensor with the concatenation of values stacked along the\nconcat_dim dimension. This Tensor\'s shape matches the Tensors in\nvalues, except in concat_dim where it has the sum of the sizes." + type_attr: "dtype" +} +summary: "Concatenate N Tensors along one dimension." +attr { + name: "dtype" + type: "type" +} +attr { + name: "num_values" + type: "int" + has_minimum: true + minimum: 2 +} +)proto"); +} + +TEST_F(OpDefBuilderTest, DocAttr) { + ExpectOrdered(b().Attr("i: int").Doc(R"doc( +Summary + +i: How much to operate. +)doc"), + R"proto( +summary: "Summary" +attr { + name: "i" + type: "int" + description: "How much to operate." +} +)proto"); +} + +TEST_F(OpDefBuilderTest, DocCalledTwiceFailure) { + ExpectFailure(b().Doc("What's").Doc("up, doc?"), + "Extra call to Doc() for Op Test"); +} + +TEST_F(OpDefBuilderTest, DocFailureMissingName) { + ExpectFailure( + b().Input("a: int32").Doc(R"doc( +Summary + +a: Something for a. +b: b is not defined. +)doc"), + "No matching input/output/attr for name 'b' from Doc() for Op Test"); + + ExpectFailure( + b().Input("a: int32").Doc(R"doc( +Summary + +b: b is not defined and by itself. 
+)doc"), + "No matching input/output/attr for name 'b' from Doc() for Op Test"); +} + +TEST_F(OpDefBuilderTest, DefaultMinimum) { + ExpectSuccess(b().Input("values: num_values * dtype") + .Output("output: anything") + .Attr("anything: list(type)") + .Attr("dtype: type") + .Attr("num_values: int"), + R"proto( +input_arg { + name: "values" + type_attr: "dtype" + number_attr: "num_values" +} +output_arg { + name: "output" + type_list_attr: "anything" +} +attr { + name: "anything" + type: "list(type)" + has_minimum: true + minimum: 1 +} +attr { + name: "dtype" + type: "type" +} +attr { + name: "num_values" + type: "int" + has_minimum: true + minimum: 1 +} +)proto"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc new file mode 100644 index 0000000000..e3aef011de --- /dev/null +++ b/tensorflow/core/framework/op_def_util.cc @@ -0,0 +1,344 @@ +#include "tensorflow/core/framework/op_def_util.h" + +#include +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { +namespace { // ------ Helper functions ------ + +bool HasAttrStyleType(const OpDef::ArgDef& arg) { + return arg.type() != DT_INVALID || !arg.type_attr().empty() || + !arg.type_list_attr().empty(); +} + +Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { + const AttrValue& allowed_values(attr.allowed_values()); + for (auto allowed : allowed_values.list().type()) { + if (dt == allowed) { + return Status::OK(); + } + } + string allowed_str; + for (int i = 0; i < allowed_values.list().type_size(); ++i) { + if (!allowed_str.empty()) { + strings::StrAppend(&allowed_str, ", "); + } + strings::StrAppend(&allowed_str, + DataTypeString(allowed_values.list().type(i))); + } + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of ", DataTypeString(dt), + " is not in the list of allowed values: ", allowed_str); +} + +Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { + const AttrValue& allowed_values(attr.allowed_values()); + for (auto allowed : allowed_values.list().s()) { + if (str == allowed) { + return Status::OK(); + } + } + string allowed_str; + for (const string& allowed : allowed_values.list().s()) { + if (!allowed_str.empty()) { + strings::StrAppend(&allowed_str, ", "); + } + strings::StrAppend(&allowed_str, "\"", allowed, "\""); + } + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of \"", str, + "\" is not in the list of allowed values: ", allowed_str); +} + +} // namespace + +// Requires: attr has already been validated. +Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr) { + // Is it a valid value? + TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()), + " for attr '", attr.name(), "'"); + + // Does the value satisfy the minimum constraint in the AttrDef? 
+ if (attr.has_minimum()) { + if (attr.type() == "int") { + if (attr_value.i() < attr.minimum()) { + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of ", attr_value.i(), + " must be at least minimum ", attr.minimum()); + } + } else { + int length = -1; + if (attr.type() == "list(string)") { + length = attr_value.list().s_size(); + } else if (attr.type() == "list(int)") { + length = attr_value.list().i_size(); + } else if (attr.type() == "list(float)") { + length = attr_value.list().f_size(); + } else if (attr.type() == "list(bool)") { + length = attr_value.list().b_size(); + } else if (attr.type() == "list(type)") { + length = attr_value.list().type_size(); + } else if (attr.type() == "list(shape)") { + length = attr_value.list().shape_size(); + } else if (attr.type() == "list(tensor)") { + length = attr_value.list().tensor_size(); + } + if (length < attr.minimum()) { + return errors::InvalidArgument( + "Length for attr '", attr.name(), "' of ", length, + " must be at least minimum ", attr.minimum()); + } + } + } + + // Does the value satisfy the allowed_value constraint in the AttrDef? + if (attr.has_allowed_values()) { + if (attr.type() == "type") { + TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr)); + } else if (attr.type() == "list(type)") { + for (int dt : attr_value.list().type()) { + TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast(dt), attr)); + } + } else if (attr.type() == "string") { + TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr)); + } else if (attr.type() == "list(string)") { + for (const string& str : attr_value.list().s()) { + TF_RETURN_IF_ERROR(AllowedStringValue(str, attr)); + } + } else { + return errors::Unimplemented( + "Support for allowed_values not implemented for type ", attr.type()); + } + } + return Status::OK(); +} + +const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { + for (int i = 0; i < op_def.attr_size(); ++i) { + if (op_def.attr(i).name() == name) { + return &op_def.attr(i); + } + } + return nullptr; +} + +OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { + for (int i = 0; i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == name) { + return op_def->mutable_attr(i); + } + } + return nullptr; +} + +#define VALIDATE(EXPR, ...) \ + do { \ + if (!(EXPR)) { \ + return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \ + op_def.ShortDebugString()); \ + } \ + } while (false) + +static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, + bool output, std::set* names) { + const string suffix = strings::StrCat( + output ? " for output '" : " for input '", arg.name(), "'"); + VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ", + arg.name()); + VALIDATE(HasAttrStyleType(arg), "Missing type", suffix); + + if (!arg.number_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'", + suffix); + VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length", + suffix, " has type ", attr->type(), " != int"); + VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length", + suffix, " must have minimum"); + VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length", + suffix, " must have minimum >= 0"); + VALIDATE(arg.type_list_attr().empty(), + "Can't have both number_attr and type_list_attr", suffix); + VALIDATE((arg.type() != DT_INVALID ? 1 : 0) + + (!arg.type_attr().empty() ? 
1 : 0) == + 1, + "Exactly one of type, type_attr must be set", suffix); + } else { + const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) + + (!arg.type_attr().empty() ? 1 : 0) + + (!arg.type_list_attr().empty() ? 1 : 0); + VALIDATE(num_type_fields == 1, + "Exactly one of type, type_attr, type_list_attr must be set", + suffix); + } + + if (!arg.type_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'", + suffix); + VALIDATE(attr->type() == "type", "Attr '", attr->name(), + "' used as type_attr", suffix, " has type ", attr->type(), + " != type"); + } else if (!arg.type_list_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'", + suffix); + VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(), + "' used as type_list_attr", suffix, " has type ", attr->type(), + " != list(type)"); + } else { + // All argument types should be non-reference types at this point. + // ArgDef.is_ref is set to true for reference arguments. + VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '", + DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix); + } + + return Status::OK(); +} + +Status ValidateOpDef(const OpDef& op_def) { + VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"), + "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); + + std::set names; // for detecting duplicate names + for (const auto& attr : op_def.attr()) { + // Validate name + VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ", + attr.name()); + DataType dt; + VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ", + attr.name(), " that matches a data type"); + + // Validate type + StringPiece type(attr.type()); + bool is_list = type.Consume("list("); + bool found = false; + for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape", + "tensor", "func"}) { + if (type.Consume(valid)) { + found = true; + break; + } + } + VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(), + "'"); + if (is_list) { + VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ", + attr.name(), "'s type ", attr.type()); + } + VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ", + attr.name(), "'s type ", attr.type()); + + // Validate minimum + if (attr.has_minimum()) { + VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(), + "' has minimum for unsupported type ", attr.type()); + if (is_list) { + VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(), + "' with list type must have a non-negative minimum, not ", + attr.minimum()); + } + } else { + VALIDATE(attr.minimum() == 0, "Attr '", attr.name(), + "' with has_minimum = false but minimum ", attr.minimum(), + " not equal to default of 0"); + } + + // Validate allowed_values + if (attr.has_allowed_values()) { + const string list_type = + is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")"); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + AttrValueHasType(attr.allowed_values(), list_type), " for attr '", + attr.name(), "' in Op '", op_def.name(), "'"); + } + + // Validate default_value (after we have validated the rest of the attr, + // so we can use ValidateAttrValue()). 
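+    // (For example, an "int" attr with minimum 7 and default_value i: 2,
+    // or a "numbertype" attr with default DT_STRING, is rejected here.)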
+ if (attr.has_default_value()) { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ValidateAttrValue(attr.default_value(), attr), " in Op '", + op_def.name(), "'"); + } + } + + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names)); + } + + for (const auto& arg : op_def.output_arg()) { + TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names)); + } + + return Status::OK(); +} + +#undef VALIDATE + +namespace { + +string SummarizeArgs(const protobuf::RepeatedPtrField& args) { + string ret; + for (const OpDef::ArgDef& arg : args) { + if (!ret.empty()) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, arg.name(), ":"); + if (arg.is_ref()) strings::StrAppend(&ret, "Ref("); + if (!arg.number_attr().empty()) { + strings::StrAppend(&ret, arg.number_attr(), "*"); + } + if (arg.type() != DT_INVALID) { + strings::StrAppend(&ret, DataTypeString(arg.type())); + } else { + strings::StrAppend(&ret, arg.type_attr()); + } + if (arg.is_ref()) strings::StrAppend(&ret, ")"); + } + return ret; +} + +} // namespace + +string SummarizeOpDef(const OpDef& op_def) { + string ret = strings::StrCat("Op ", SummarizeArgs(op_def.output_arg())); + for (int i = 0; i < op_def.attr_size(); ++i) { + strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":", + op_def.attr(i).type()); + if (op_def.attr(i).has_default_value()) { + strings::StrAppend(&ret, ",default=", + SummarizeAttrValue(op_def.attr(i).default_value())); + } + if (op_def.attr(i).has_minimum()) { + strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum()); + } + if (op_def.attr(i).has_allowed_values()) { + strings::StrAppend(&ret, ",allowed=", + SummarizeAttrValue(op_def.attr(i).allowed_values())); + } + } + if (op_def.is_commutative()) { + strings::StrAppend(&ret, "; is_commutative=true"); + } + if (op_def.is_aggregate()) { + strings::StrAppend(&ret, "; is_aggregate=true"); + } + if (op_def.is_stateful()) { + strings::StrAppend(&ret, "; is_stateful=true"); + } + if (op_def.allows_uninitialized_input()) { + strings::StrAppend(&ret, "; allows_uninitialized_input=true"); + } + strings::StrAppend(&ret, ">"); + return ret; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h new file mode 100644 index 0000000000..a9fecf3fa0 --- /dev/null +++ b/tensorflow/core/framework/op_def_util.h @@ -0,0 +1,32 @@ +// TODO(josh11b): Probably not needed for OpKernel authors, so doesn't +// need to be as publicly accessible as other files in framework/. + +#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ + +#include +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Performs a consistency check across the fields of the op_def. +Status ValidateOpDef(const OpDef& op_def); + +// Validates that attr_value satisfies the type and constraints from attr. +// REQUIRES: attr has already been validated. +Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr); + +// The following search through op_def for an attr with the indicated name. +// Returns nullptr if no such attr is found. +const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def); +OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def); + +// Produce a human-readable version of an op_def that is more concise +// than a text-format proto. Excludes descriptions. 
+string SummarizeOpDef(const OpDef& op_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc new file mode 100644 index 0000000000..515e8bb288 --- /dev/null +++ b/tensorflow/core/framework/op_def_util_test.cc @@ -0,0 +1,330 @@ +#include "tensorflow/core/framework/op_def_util.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include + +namespace tensorflow { +namespace { + +OpDef FromText(const string& text) { + OpDef op_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &op_def)); + return op_def; +} + +class ValidateOpDefTest : public ::testing::Test { + protected: + Status TestProto(const string& text) { + return ValidateOpDef(FromText(text)); + } + + Status TestBuilder(const OpDefBuilder& builder) { + OpDef op_def; + Status status = builder.Finalize(&op_def); + EXPECT_OK(status); + if (!status.ok()) { + return status; + } else { + return ValidateOpDef(op_def); + } + } + + void ExpectFailure(const Status& status, const string& message) { + EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; + if (!status.ok()) { + LOG(INFO) << "message: " << status; + EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + << "Actual: " << status << "\nExpected to contain: " << message; + } + } +}; + +TEST_F(ValidateOpDefTest, OpDefValid) { + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Input("a: int32"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Output("a: bool"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("t: type").Input("a: t"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int = 3"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5 = 3"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: numbertype"))); + EXPECT_OK(TestBuilder(OpDefBuilder("Uppercase"))); +} + +TEST_F(ValidateOpDefTest, InvalidName) { + ExpectFailure(TestBuilder(OpDefBuilder("lower").Attr("a: int")), + "Invalid name"); + ExpectFailure(TestBuilder(OpDefBuilder("BadSuffix 7%")), "Invalid name"); +} + +TEST_F(ValidateOpDefTest, DuplicateName) { + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Input("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder( + OpDefBuilder("DupeName").Input("a: int32").Output("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder( + OpDefBuilder("DupeName").Output("a: int32").Output("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Attr("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Output("a: int32").Attr("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Attr("a: int").Attr("a: float")), + "Duplicate name: a"); +} + +TEST_F(ValidateOpDefTest, BadAttrName) { + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("int32: int")), + "Attr can't have name int32 that matches a data type"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("float: 
string")), + "Attr can't have name float that matches a data type"); +} + +TEST_F(ValidateOpDefTest, BadAttrType) { + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'illegal' }"), + "Unrecognized type"); + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'list(illegal)' }"), + "Unrecognized type"); + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'int extra' }"), + "Extra ' extra' at the end"); + ExpectFailure( + TestProto( + "name: 'BadAttrType' attr { name: 'a' type: 'list(int extra)' }"), + "'list(' is missing ')' in attr"); + ExpectFailure( + TestProto( + "name: 'BadAttrType' attr { name: 'a' type: 'list(int) extra' }"), + "Extra ' extra' at the end"); +} + +TEST_F(ValidateOpDefTest, BadAttrDefault) { + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'int' default_value { s: 'x' } }"), + "AttrValue had value with type string when int expected\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'int' default_value { f: 0.5 } }"), + "AttrValue had value with type float when int expected\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'int' " + "default_value { i: 5 list { i: [2] } } }"), + "AttrValue had value with type list(int) when int expected\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { f: 0.5 } }"), + "AttrValue had value with type float when list(int) expected\n\t " + "for attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'list(int)' " + "default_value { list { i: [5] f: [0.5] } } }"), + "AttrValue had value with type list(float) when list(int) " + "expected\n\t for attr 'a'\n\t in Op 'BadAttrDef'"); + + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'type' default_value { } }"), + "AttrValue missing value with expected type type\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'shape' default_value { } }"), + "AttrValue missing value with expected type shape\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'tensor' default_value { } }"), + "AttrValue missing value with expected type tensor\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + + // default_value {} is indistinguishable from default_value{ list{} } (one + // with an empty list) in proto3 semantics. 
+ EXPECT_OK( + TestProto("name: 'GoodAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { } }")); + + // Empty lists are allowed: + EXPECT_OK( + TestProto("name: 'GoodAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { list { } } }")); + // Builder should make the same proto: + EXPECT_OK(TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(int) = []"))); + + // Unless there is a minimum length specified: + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'list(int)' has_minimum: true minimum: 2 " + "default_value { list { } } }"), + "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op " + "'BadAttrDef'"); + ExpectFailure( + TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(bool) >=2 = []")), + "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op " + "'GoodAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' type: " + "'list(string)' has_minimum: true minimum: 2 " + "default_value { list { s: ['foo'] } } }"), + "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " + "'BadAttrDef'"); + ExpectFailure(TestBuilder(OpDefBuilder("GoodAttrDef") + .Attr("a: list(type) >=2 = [DT_STRING]")), + "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " + "'GoodAttrDef'"); +} + +TEST_F(ValidateOpDefTest, NoRefTypes) { + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef").Input("i: float_ref")), + "Illegal use of ref type 'float_ref'. " + "Use 'Ref(type)' instead for input 'i'"); + ExpectFailure( + TestBuilder(OpDefBuilder("BadAttrDef").Attr("T: type = DT_INT32_REF")), + "AttrValue must not have reference type value of int32_ref"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef") + .Attr("T: list(type) = [DT_STRING_REF]")), + "AttrValue must not have reference type value of string_ref"); +} + +TEST_F(ValidateOpDefTest, BadAttrMin) { + ExpectFailure(TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'string' " + "has_minimum: true minimum: 0 }"), + "minimum for unsupported type string"); + ExpectFailure( + TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'int' default_value " + "{ i: 2 } has_minimum: true minimum: 7 }"), + "Value for attr 'a' of 2 must be at least minimum 7\n\t in Op " + "'BadAttrMin'"); + ExpectFailure( + TestProto("name: 'BadAttrMin' attr { name: 'a' " + "type: 'list(string)' has_minimum: true minimum: -5 }"), + "list type must have a non-negative minimum, not -5"); + EXPECT_OK( + TestProto("name: 'GoodAttrMin' attr { name: 'a' type: 'list(string)' " + "has_minimum: true minimum: 1 }")); + ExpectFailure(TestProto("name: 'NoHasMin' attr { name: 'a' " + "type: 'list(string)' minimum: 3 }"), + "Attr 'a' with has_minimum = false but minimum 3 not equal to " + "default of 0"); +} + +TEST_F(ValidateOpDefTest, BadAttrAllowed) { + // Is in list of allowed types. + EXPECT_OK(TestBuilder( + OpDefBuilder("GoodAttrtude").Attr("x: numbertype = DT_INT32"))); + // Not in list of allowed types. + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: numbertype = DT_STRING")), + "attr 'x' of string is not in the list of allowed values"); + ExpectFailure( + TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: list(realnumbertype) = [DT_COMPLEX64]")), + "attr 'x' of complex64 is not in the list of allowed values"); + // Is in list of allowed strings. + EXPECT_OK(TestBuilder( + OpDefBuilder("GoodAttrtude").Attr("x: {'foo', 'bar'} = 'bar'"))); + // Not in list of allowed strings. 
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: {'foo', 'bar'} = 'baz'")), + "attr 'x' of \"baz\" is not in the list of allowed values"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: list({'foo', 'bar'}) = ['baz']")), + "attr 'x' of \"baz\" is not in the list of allowed values"); + ExpectFailure(TestProto( + "name: 'BadAttrtude' attr { name: 'a' " + "type: 'string' allowed_values { s: 'not list' } }"), + "with type string when list(string) expected"); + ExpectFailure( + TestProto("name: 'BadAttrtude' attr { name: 'a' " + "type: 'string' allowed_values { list { i: [6] } } }"), + "with type list(int) when list(string) expected"); +} + +TEST_F(ValidateOpDefTest, BadArgType) { + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type: DT_INT32 } input_arg { name: 'b' }"), + "Missing type for input 'b'"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type: DT_INT32 } output_arg { name: 'b' }"), + "Missing type for output 'b'"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' type: " + "DT_INT32 type_attr: 'x' } attr { name: 'x' type: 'type' }"), + "Exactly one of type, type_attr, type_list_attr must be set for input " + "'a'"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' } attr { name: 'x' type: 'int' }"), + "Attr 'x' used as type_attr for input 'a' has type int"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' } attr { name: 'x' type: 'list(type)' }"), + "Attr 'x' used as type_attr for input 'a' has type list(type)"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_list_attr: 'x' } attr { name: 'x' type: 'int' }"), + "Attr 'x' used as type_list_attr for input 'a' has type int"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_list_attr: 'x' } attr { name: 'x' type: 'type' }"), + "Attr 'x' used as type_list_attr for input 'a' has type type"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' }"), + "No attr with name 'x' for input 'a'"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: 'n' " + "type_attr: 'x' } attr { name: 'x' type: 'list(type)' } " + "attr { name: 'n' type: 'int' has_minimum: true minimum: 1 }"), + "Attr 'x' used as type_attr for input 'a' has type list(type)"); + // But list(type) is fine as the type of an arg without a number_attr: + EXPECT_OK(TestProto( + "name: 'Arg' input_arg { name: 'a' type_list_attr: 'x' } " + "attr { name: 'x' type: 'list(type)' } attr { name: 'n' type: 'int' " + "has_minimum: true minimum: 1 }")); + + // number_attr + EXPECT_OK(TestProto( + "name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: 'n' } " + "attr { name: 'n' type: 'int' has_minimum: true minimum: 0 }")); + + ExpectFailure(TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 " + "number_attr: 'n' }"), + "No attr with name 'n'"); + ExpectFailure( + TestProto( + "name: 'Arg' input_arg { name: 'a' type: " + "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'string' }"), + "Attr 'n' used as length for input 'a' has type string"); + ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' type: " + "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'int' }"), + "Attr 'n' used as length for input 'a' must have minimum;"); + ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: " + "'n' } attr { name: 'n' type: 'int' has_minimum: true 
minimum: " + "-5 }"), + "Attr 'n' used as length for input 'a' must have minimum >= 0;"); + ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' number_attr: 'n' } attr { " + "name: 'n' type: 'int' has_minimum: true minimum: 2 }"), + "Missing type for input 'a'; in OpDef:"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: " + "'n' type_list_attr: 'x' } attr { name: 'n' type: " + "'int' has_minimum: true minimum: 1 } attr { name: " + "'x' type: 'list(type)' }"), + "Can't have both number_attr and type_list_attr for input 'a'"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc new file mode 100644 index 0000000000..04f4b7cacd --- /dev/null +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -0,0 +1,55 @@ +#include "tensorflow/core/framework/op_gen_lib.h" + +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +string WordWrap(StringPiece prefix, StringPiece str, int width) { + const string indent_next_line = "\n" + Spaces(prefix.size()); + width -= prefix.size(); + string result; + strings::StrAppend(&result, prefix); + + while (!str.empty()) { + if (static_cast(str.size()) <= width) { + // Remaining text fits on one line. + strings::StrAppend(&result, str); + break; + } + auto space = str.rfind(' ', width); + if (space == StringPiece::npos) { + // Rather make a too-long line and break at a space. + space = str.find(' '); + if (space == StringPiece::npos) { + strings::StrAppend(&result, str); + break; + } + } + // Breaking at character at position . + StringPiece to_append = str.substr(0, space); + str.remove_prefix(space + 1); + // Remove spaces at break. + while (to_append.ends_with(" ")) { + to_append.remove_suffix(1); + } + while (str.Consume(" ")) { + } + + // Go on to the next line. + strings::StrAppend(&result, to_append); + if (!str.empty()) strings::StrAppend(&result, indent_next_line); + } + + return result; +} + +bool ConsumeEquals(StringPiece* description) { + if (description->Consume("=")) { + while (description->Consume(" ")) { // Also remove spaces after "=". + } + return true; + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h new file mode 100644 index 0000000000..9890f1bcec --- /dev/null +++ b/tensorflow/core/framework/op_gen_lib.h @@ -0,0 +1,24 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_ +#define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_ + +#include +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +inline string Spaces(int n) { return string(n, ' '); } + +// Wrap prefix + str to be at most width characters, indenting every line +// after the first by prefix.size() spaces. Intended use case is something +// like prefix = " Foo(" and str is a list of arguments (terminated by a ")"). +// TODO(josh11b): Option to wrap on ", " instead of " " when possible. +string WordWrap(StringPiece prefix, StringPiece str, int width); + +// Looks for an "=" at the beginning of *description. If found, strips it off +// (and any following spaces) from *description and return true. Otherwise +// returns false. 
+bool ConsumeEquals(StringPiece* description); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_ diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc new file mode 100644 index 0000000000..eb83d393f0 --- /dev/null +++ b/tensorflow/core/framework/op_kernel.cc @@ -0,0 +1,749 @@ +#include "tensorflow/core/framework/op_kernel.h" + +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +namespace { + +Status MatchSignatureHelper(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs, + const DataTypeSlice inputs, + const DataTypeSlice outputs) { + bool signature_mismatch = false; + + if (inputs.size() != expected_inputs.size()) signature_mismatch = true; + for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) { + if (!TypesCompatible(expected_inputs[i], inputs[i])) { + signature_mismatch = true; + } + } + + if (outputs.size() != expected_outputs.size()) signature_mismatch = true; + for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) { + if (!TypesCompatible(expected_outputs[i], outputs[i])) { + signature_mismatch = true; + } + } + + if (signature_mismatch) { + return errors::InvalidArgument("Signature mismatch, have: ", + DataTypeSliceString(inputs), "->", + DataTypeSliceString(outputs), " expected: ", + DataTypeSliceString(expected_inputs), "->", + DataTypeSliceString(expected_outputs)); + } + return Status::OK(); +} + +// Check HostMemory backward compatibility. +bool CheckHostMemoryCompatibility(const DeviceType device_type, + const OpKernel* kernel) { + if (device_type == DEVICE_GPU) { + for (int i = 0; i < kernel->num_inputs(); ++i) { + if (kernel->input_type(i) == DT_INT32 && + kernel->input_memory_types()[i] != HOST_MEMORY) { + return false; + } + } + for (int i = 0; i < kernel->num_outputs(); ++i) { + if (kernel->output_type(i) == DT_INT32 && + kernel->output_memory_types()[i] != HOST_MEMORY) { + return false; + } + } + } + return true; +} + +} // namespace + +// OpKernel ------------------------------------------------------------------ + +OpKernel::OpKernel(OpKernelConstruction* context) + : def_(context->def()), + input_types_(context->input_types().begin(), + context->input_types().end()), + output_types_(context->output_types().begin(), + context->output_types().end()), + input_name_map_(context->num_inputs()), + output_name_map_(context->num_outputs()) { + OP_REQUIRES_OK(context, + NameRangesForNode(def_, context->op_def(), &input_name_map_, + &output_name_map_)); + + // By default, the input and output memory types are always in device memory, + // but can be overridden by individual implementations of OpKernels in their + // constructor. 
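+  // (For example, GPU kernels typically keep DT_INT32 inputs and outputs in
+  // HOST_MEMORY; the CheckHostMemoryCompatibility() check below logs any
+  // int32 tensors left in device memory on GPU.)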
+ input_memory_types_ = MemoryTypeVector(input_types_.size(), DEVICE_MEMORY); + output_memory_types_ = MemoryTypeVector(output_types_.size(), DEVICE_MEMORY); + // TODO(yuanbyu): For now we assume the memory types of function + // inputs/outputs to be DEVICE_MEMORY. + auto lib = context->function_library(); + if (lib == nullptr || !lib->IsDefined(def_.op())) { + OP_REQUIRES_OK(context, MemoryTypesForNode( + context->device_type(), def_, context->op_def(), + input_name_map_, output_name_map_, + &input_memory_types_, &output_memory_types_)); + // Log all the uses of int32 on GPU. + // TODO(yunabyu): Remove once everyone transitions to HostMemory. + if (VLOG_IS_ON(2)) { + if (!CheckHostMemoryCompatibility(context->device_type(), this)) { + VLOG(2) << "Using int32 on GPU at node: " << SummarizeNodeDef(def()); + } + } + } +} + +Status OpKernel::InputRange(const string& input_name, int* start, + int* stop) const { + const auto result = input_name_map_.find(input_name); + if (result == input_name_map_.end()) { + return errors::InvalidArgument("Unknown input name: ", input_name); + } else { + *start = result->second.first; + *stop = result->second.second; + return Status::OK(); + } +} + +Status OpKernel::OutputRange(const string& output_name, int* start, + int* stop) const { + const auto result = output_name_map_.find(output_name); + if (result == output_name_map_.end()) { + return errors::InvalidArgument("Unknown output name: ", output_name); + } else { + *start = result->second.first; + *stop = result->second.second; + return Status::OK(); + } +} + +void AsyncOpKernel::Compute(OpKernelContext* context) { + Notification n; + ComputeAsync(context, [&n]() { n.Notify(); }); + n.WaitForNotification(); +} + +// PersistentTensor ---------------------------------------------------------- + +Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) { + // the caller has to have a valid context + CHECK(context); + return &tensor_; +} + +Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { + context->NotifyUseOfPersistentTensor(tensor_); + return &tensor_; +} + +// OpKernelConstruction ------------------------------------------------------ + +Status OpKernelConstruction::MatchSignature( + const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { + return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_, + output_types_); +} + +Status OpKernelConstruction::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp) { + Tensor new_temp(allocator_, type, shape); + + if (!new_temp.IsInitialized() && shape.num_elements() > 0) { + return errors::ResourceExhausted( + "OOM when allocating temporary tensor with shape", shape.DebugString()); + } + *out_temp = new_temp; + return Status::OK(); +} + +Status OpKernelConstruction::allocate_persistent( + DataType type, const TensorShape& shape, PersistentTensor* out_persistent, + Tensor** out_tensor) { + // for now just do the same thing as allocate_temp + // TODO(misard) add specific memory tracking for persistent tensors + Tensor persistent; + Status s = allocate_temp(type, shape, &persistent); + if (!s.ok()) { + return s; + } + *out_persistent = PersistentTensor(persistent); + Tensor* allocated = out_persistent->AccessTensor(this); + if (out_tensor) { + *out_tensor = allocated; + } + return s; +} + +// OpKernelContext ----------------------------------------------------------- + +OpKernelContext::OpKernelContext(const Params& params) + : params_(params), + 
outputs_(params.op_kernel->output_types().size()), + output_allocation_types_(params.op_kernel->output_types().size()) { + Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); + eigen_gpu_device_ = params_.device->MakeGpuDevice(params_.op_device_context, + eigen_gpu_allocator); +} + +OpKernelContext::~OpKernelContext() { + for (TensorValue& value : outputs_) { + if (!value.is_ref()) { + delete value.tensor; + } + } + for (Tensor* t : temp_tensors_) delete t; + delete eigen_gpu_device_; +} + +Status OpKernelContext::input(const string& name, const Tensor** tensor) const { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + if ((*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used ref input name '", name, + "' when immutable input was expected"); + } + *tensor = (*params_.inputs)[start].tensor; + return Status::OK(); +} + +Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + *out_mutex = input_ref_mutex(start); + return Status::OK(); +} + +Status OpKernelContext::mutable_input(const string& name, Tensor* tensor, + bool lock_held) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + if (!(*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used immutable input name '", name, + "' when ref input was expected"); + } + // return a copy of the Ref acquired while holding the mutex + if (lock_held) { + *tensor = *(*params_.inputs)[start].tensor; + } else { + mutex_lock l(*input_ref_mutex(start)); + *tensor = *(*params_.inputs)[start].tensor; + } + return Status::OK(); +} + +Status OpKernelContext::replace_ref_input(const string& name, + const Tensor& tensor, + bool lock_held) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + if (!(*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used immutable input name '", name, + "' when ref input was expected"); + } + replace_ref_input(start, tensor, lock_held); + return Status::OK(); +} + +Status OpKernelContext::input_list(const string& name, + OpInputList* list) const { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + *list = OpInputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::mutable_input_list(const string& name, + OpMutableInputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + *list = OpMutableInputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::output_list(const string& name, OpOutputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + *list = 
OpOutputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::allocate_output(const string& name, + const TensorShape& shape, + Tensor** tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + return allocate_output(start, shape, tensor); +} + +Status OpKernelContext::allocate_output(const string& name, + const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + return allocate_output(start, shape, tensor, attr); +} + +Status OpKernelContext::set_output(const string& name, const Tensor& tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + set_output(start, tensor); + return Status::OK(); +} + +Status OpKernelContext::set_output_ref(const string& name, mutex* mu, + Tensor* tensor_for_ref) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + set_output_ref(start, mu, tensor_for_ref); + return Status::OK(); +} + +Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + *tensor = mutable_output(start); + return Status::OK(); +} + +Status OpKernelContext::release_output(const string& name, TensorValue* value) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + *value = release_output(start); + return Status::OK(); +} + +bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { + const auto& inputs = *params_.inputs; + for (size_t i = 1; i < inputs.size(); ++i) { + if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) { + SetStatus(errors::InvalidArgument( + "Inputs to operation ", op->name(), " of type ", op->type_string(), + " must have the same size and shape. Input 0: ", + inputs[0]->shape().DebugString(), " != input ", i, ": ", + inputs[i]->shape().DebugString())); + return false; + } + } + return true; +} + +Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs) { + DataTypeVector inputs; + for (const TensorValue& t : *params_.inputs) { + inputs.push_back(t.is_ref() ? 
MakeRefType(t->dtype()) : t->dtype()); + } + DataTypeVector outputs = params_.op_kernel->output_types(); + return MatchSignatureHelper(expected_inputs, expected_outputs, inputs, + outputs); +} + +// OpKernel registration ------------------------------------------------------ + +struct KernelRegistration { + KernelRegistration(const KernelDef& d, + kernel_factory::OpKernelRegistrar::Factory f) + : def(d), factory(f) {} + const KernelDef def; + const kernel_factory::OpKernelRegistrar::Factory factory; +}; + +// This maps from 'op_type' + DeviceType to the set of KernelDefs and +// factory functions for instantiating the OpKernel that matches the +// KernelDef. +typedef std::unordered_multimap KernelRegistry; + +static KernelRegistry* GlobalKernelRegistry() { + static KernelRegistry* global_kernel_registry = new KernelRegistry; + return global_kernel_registry; +} + +static string Key(const string& op_type, DeviceType device_type, + const string& label) { + return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", + label); +} + +namespace kernel_factory { + +OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def, + Factory factory) { + const string key = + Key(kernel_def->op(), DeviceType(kernel_def->device_type()), + kernel_def->label()); + GlobalKernelRegistry()->insert( + std::make_pair(key, KernelRegistration(*kernel_def, factory))); + delete kernel_def; +} + +} // namespace kernel_factory + +namespace { + +// Helper for AttrsMatch(). +bool InTypeList(DataType dt, const AttrValue& type_list) { + for (int in_list : type_list.list().type()) { + if (dt == in_list) return true; + } + return false; +} + +// Returns whether the attrs in the NodeDef satisfy the constraints in +// the kernel_def. Returns an error if attrs in kernel_def are not +// found, or have a mismatching type. +Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, + bool* match) { + *match = false; + AttrSlice attrs(node_def); + for (const auto& constraint : kernel_def.constraint()) { + if (constraint.allowed_values().list().type_size() == 0) { + return errors::Unimplemented( + "KernelDef '", kernel_def.ShortDebugString(), + " has constraint on attr '", constraint.name(), + "' with unsupported type: ", + SummarizeAttrValue(constraint.allowed_values())); + } + + const AttrValue* found = attrs.Find(constraint.name()); + if (found) { + if (found->type() != DT_INVALID) { + if (!InTypeList(found->type(), constraint.allowed_values())) { + return Status::OK(); + } + } else { + if (!AttrValueHasType(*found, "list(type)").ok()) { + return errors::InvalidArgument( + "KernelDef '", kernel_def.ShortDebugString(), + "' has constraint on attr '", constraint.name(), + "' that has value '", SummarizeAttrValue(*found), + "' that does not have type 'type' or 'list(type)' in NodeDef '", + SummarizeNodeDef(node_def), "'"); + } + + for (int t : found->list().type()) { + if (!InTypeList(static_cast(t), + constraint.allowed_values())) { + return Status::OK(); + } + } + } + } else { + return errors::InvalidArgument( + "OpKernel '", kernel_def.op(), "' has constraint on attr '", + constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def), + "', KernelDef: '", kernel_def.ShortDebugString(), "'"); + } + } + *match = true; + return Status::OK(); +} + +Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, + const KernelRegistration** reg) { + *reg = nullptr; + string label; // Label defaults to empty if not found in NodeDef. 
+ GetNodeAttr(node_def, "_kernel", &label); + const string key = Key(node_def.op(), device_type, label); + auto regs = GlobalKernelRegistry()->equal_range(key); + for (auto iter = regs.first; iter != regs.second; ++iter) { + // If there is a kernel registered for the op and device_type, + // check that the attrs match. + bool match; + TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match)); + if (match) { + if (*reg != nullptr) { + return errors::InvalidArgument( + "Multiple OpKernel registrations match NodeDef '", + SummarizeNodeDef(node_def), "': '", (*reg)->def.ShortDebugString(), + "' and '", iter->second.def.ShortDebugString(), "'"); + } + *reg = &iter->second; + } + } + return Status::OK(); +} + +} // namespace + +Status SupportedDeviceTypesForNode( + const std::vector& prioritized_types, const NodeDef& def, + DeviceTypeVector* device_types) { + // TODO(zhifengc): Changes the callers (SimplePlacer and + // DynamicPlacer) to consider the possibility that 'def' is call to + // a user-defined function and only calls this + // SupportedDeviceTypesForNode for primitive ops. + Status s; + const OpDef* op_def = OpRegistry::Global()->LookUp(def.op(), &s); + if (op_def) { + for (const DeviceType& device_type : prioritized_types) { + const KernelRegistration* reg = nullptr; + TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, def, ®)); + if (reg != nullptr) device_types->push_back(device_type); + } + } else { + // Assumes that all device types support this node. + for (const DeviceType& device_type : prioritized_types) { + device_types->push_back(device_type); + } + } + return Status::OK(); +} + +std::unique_ptr CreateOpKernel(DeviceType device_type, + DeviceBase* device, + Allocator* allocator, + const NodeDef& node_def, + Status* status) { + OpKernel* kernel = nullptr; + *status = CreateOpKernel(device_type, device, allocator, nullptr, node_def, + &kernel); + return std::unique_ptr(kernel); +} + +Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const NodeDef& node_def, OpKernel** kernel) { + VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); + + // Look up the Op registered for this op name. + Status s; + const OpDef* op_def = OpRegistry::Global()->LookUp(node_def.op(), &s); + if (op_def == nullptr) return s; + + // Validate node_def against OpDef. + s = ValidateNodeDef(node_def, *op_def); + if (!s.ok()) return s; + + // Look up kernel registration. + const KernelRegistration* registration; + s = FindKernelRegistration(device_type, node_def, ®istration); + if (!s.ok()) { + errors::AppendToMessage(&s, " when instantiating ", node_def.op()); + return s; + } + if (registration == nullptr) { + s.Update(errors::NotFound("No registered '", node_def.op(), + "' OpKernel for ", DeviceTypeString(device_type), + " devices compatible with node ", + SummarizeNodeDef(node_def))); + return s; + } + + // Get signature from the OpDef & NodeDef + DataTypeVector inputs; + DataTypeVector outputs; + s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs)); + if (!s.ok()) { + errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def)); + return s; + } + + // Everything needed for OpKernel construction. + OpKernelConstruction context(device_type, device, allocator, &node_def, + op_def, flib, inputs, outputs, &s); + *kernel = (*registration->factory)(&context); + if (!s.ok()) { + delete *kernel; + *kernel = nullptr; + } + return s; +} + +namespace { // Helper for MemoryTypesForNode. 
+// Fills memory_types for either input or output, setting everything +// to DEVICE_MEMORY except those args in host_memory_args. Removes +// elements of host_memory_args that were used. +void MemoryTypesHelper(const NameRangeMap& name_map, + std::vector* host_memory_args, + MemoryTypeVector* memory_types) { + // Set total to the largest endpoint of anything in the name_map. + int total = 0; + for (const auto& item : name_map) { + total = std::max(total, item.second.second); + } + + // Now that we know the size, fill with the default 'DEVICE_MEMORY'. + memory_types->clear(); + memory_types->resize(total, DEVICE_MEMORY); + + // Update args that have been marked as in "HOST_MEMORY". + size_t keep = 0; + for (size_t i = 0; i < host_memory_args->size(); ++i) { + auto iter = name_map.find((*host_memory_args)[i]); + if (iter != name_map.end()) { + for (int j = iter->second.first; j < iter->second.second; ++j) { + (*memory_types)[j] = HOST_MEMORY; + } + } else { + // (*host_memory_args)[i] not found, save it for the next pass. + if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i]; + ++keep; + } + } + host_memory_args->resize(keep); +} +} // namespace + +Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef, + const OpDef& op_def, + const NameRangeMap& input_name_map, + const NameRangeMap& output_name_map, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types) { + Status status; + const KernelRegistration* registration; + TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, ndef, ®istration)); + + if (registration != nullptr) { + const auto& from_proto = registration->def.host_memory_arg(); + std::vector host_memory_args(from_proto.begin(), from_proto.end()); + MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types); + MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types); + if (!host_memory_args.empty()) { + return errors::InvalidArgument( + "HostMemory args '", str_util::Join(host_memory_args, "', '"), + "' not found in OpDef: ", SummarizeOpDef(op_def)); + } + } + return status; +} + +Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + DeviceType device_type, const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types) { + // Look up the Op registered for this op name. + Status status; + const OpDef* op_def = op_registry->LookUp(ndef.op(), &status); + if (op_def == nullptr) return status; + + NameRangeMap inputs, outputs; + status = NameRangesForNode(ndef, *op_def, &inputs, &outputs); + if (!status.ok()) return status; + + return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs, + input_memory_types, output_memory_types); +} + +namespace { + +bool FindArgInOp(const string& arg_name, + const protobuf::RepeatedPtrField& args) { + for (const auto& arg : args) { + if (arg_name == arg.name()) { + return true; + } + } + return false; +} + +} // namespace + +Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry) { + Status unused_status; + for (const auto& key_registration : *GlobalKernelRegistry()) { + const KernelDef& kernel_def(key_registration.second.def); + const OpDef* op_def = op_registry->LookUp(kernel_def.op(), &unused_status); + if (op_def == nullptr) { + // TODO(josh11b): Make this a hard error. 
+ LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString() + << "') for unknown op: " << kernel_def.op(); + continue; + } + for (const auto& host_memory_arg : kernel_def.host_memory_arg()) { + if (!FindArgInOp(host_memory_arg, op_def->input_arg()) && + !FindArgInOp(host_memory_arg, op_def->output_arg())) { + return errors::InvalidArgument("HostMemory arg '", host_memory_arg, + "' not found in OpDef: ", + SummarizeOpDef(*op_def)); + } + } + } + return Status::OK(); +} + +template <> +const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const { + return eigen_cpu_device(); +} + +template <> +const Eigen::GpuDevice& OpKernelContext::eigen_device() const { + return eigen_gpu_device(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h new file mode 100644 index 0000000000..34d588c6c9 --- /dev/null +++ b/tensorflow/core/framework/op_kernel.h @@ -0,0 +1,1250 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ +#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ + +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace Eigen { +class ThreadPoolDevice; +class GpuDevice; +} // end namespace Eigen + +namespace tensorflow { + +namespace checkpoint { +class TensorSliceReaderCacheWrapper; +} // namespace checkpoint + +class AsyncOpKernel; +class OpKernelConstruction; // declared below +class OpKernelContext; // declared below +class ResourceMgr; + +// TODO(josh11b): Make reference-counted if needed. +class OpKernel { + public: + // OpKernel won't be instantiated by the scheduler, so you may perform + // expensive initialization in the descendant's constructor. + explicit OpKernel(OpKernelConstruction* context); + virtual ~OpKernel() {} + + // An OpKernel's computation can be either synchronous or + // asynchronous. + // + // Most OpKernels should compute synchronously. They should + // subclass OpKernel and override the Compute() method and have it + // return after completing the supplied work. + // + // A few special kernels might need to be asynchronous to bound the + // number of threads (e.g., network receive operations). These + // kernels must subclass AsyncOpKernel and override + // AsyncOpKernel::ComputeAsync(). 
+ // + // In both cases, implementations of Compute() and ComputeAsync() + // get inputs and write outputs through the given OpKernelContext + // and returns a status via context->SetStatus(). They must be + // thread-safe. + + // Synchronous compute. + // + // "context" is guaranteed to be alive until Compute() returns. + virtual void Compute(OpKernelContext* context) = 0; + + // Returns nullptr iff this op kernel is synchronous. + virtual AsyncOpKernel* AsAsync() { return nullptr; } + + // Returns true iff this op kernel is considered "expensive". The + // runtime may use this flag to optimize graph execution for example + // to "inline" inexpensive kernels. + virtual bool IsExpensive() { return true; } + + // Accessors. + const NodeDef& def() const { return def_; } + const string& name() const { return def_.name(); } + const string& type_string() const { return def_.op(); } + + int num_inputs() const { return input_types_.size(); } + DataType input_type(int i) const { return input_types_[i]; } + const DataTypeVector& input_types() const { return input_types_; } + const MemoryTypeVector& input_memory_types() const { + return input_memory_types_; + } + + int num_outputs() const { return output_types_.size(); } + DataType output_type(int o) const { return output_types_[o]; } + const DataTypeVector& output_types() const { return output_types_; } + const MemoryTypeVector& output_memory_types() const { + return output_memory_types_; + } + + Status InputRange(const string& input_name, int* start, int* stop) const; + Status OutputRange(const string& output_name, int* start, int* stop) const; + + private: + const NodeDef def_; + const DataTypeVector input_types_; + const DataTypeVector output_types_; + NameRangeMap input_name_map_; + NameRangeMap output_name_map_; + MemoryTypeVector input_memory_types_; + MemoryTypeVector output_memory_types_; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernel); +}; + +class AsyncOpKernel : public OpKernel { + public: + using OpKernel::OpKernel; // Lift OpKernel constructors. + + // Asynchronous compute. + // + // Implementations of ComputeAsync() must run "done" to signal the + // completion of the computation. "context" is guaranteed to be + // alive until the "done" callback starts. + typedef std::function DoneCallback; + virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0; + + AsyncOpKernel* AsAsync() final { return this; } + + void Compute(OpKernelContext* context) final; +}; + +// Wraps a tensor that is held by an Op across calls to Compute(). For +// memory safety when using asynchronous devices like GPUs, the system +// must be notified when a Tensor is used inside an Op execution. The +// wrapper ensures that all uses of the Tensor are tracked, because in +// order to retrieve the Tensor the caller must use AccessTensor which +// notifies the context. +class PersistentTensor { + public: + PersistentTensor() {} + explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {} + + // Caller does not own the returned Tensor*. + Tensor* AccessTensor(OpKernelConstruction* context); + // Caller does not own the returned Tensor*. + Tensor* AccessTensor(OpKernelContext* context); + + // The check for initialization does not need to access the + // underlying tensor buffer. + bool IsInitialized() { return tensor_.IsInitialized(); } + + private: + Tensor tensor_; +}; + +class OpKernelConstruction { + public: + // TODO(yuanbyu): Probably reduce the number of arguments. 
+ OpKernelConstruction(DeviceType device_type, DeviceBase* device, + Allocator* allocator, const NodeDef* node_def, + const OpDef* op_def, FunctionLibraryRuntime* flib, + const DataTypeSlice& input_types, + const DataTypeSlice& output_types, Status* status) + : device_type_(device_type), + device_(device), + allocator_(allocator), + def_(node_def), + op_def_(op_def), + flib_(flib), + input_types_(input_types), + output_types_(output_types), + status_(status) {} + + Env* env() const { return device_->env(); } + + // Allocation of tensors during kernel construction: + // + // It is legal to temporarily allocate scratch tensor storage during + // Op kernel construction. Scratch tensors should be allocated using + // allocate_temp below. Some kernels need to keep tensors in between + // invocations. If such a Tensor is allocated during kernel + // construction this must be done using allocate_persistent, and the + // Op may only store the returned PersistentTensor object. When the + // Tensor is needed in a subsequent invocation, it can be retrieved + // from the PersistentTensor using the AccessTensor method. This + // ensures that the system is made aware of any use of the tensor's + // allocated memory, which is needed for correctness on asynchronous + // devices such as GPUs. + + // Allocates a temporary Tensor of the specified type and shape. The + // Tensor must not be used after kernel construction is + // complete. See comment above. + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); + + // Allocates a Tensor of the specified type and shape which the Op + // plans to maintain as persistent state. out_persistent holds the + // PersistentTensor which is the object the caller should store. For + // convenience, if out_tensor is non-null then it will be filled in + // with a Tensor* pointing to the newly-allocated tensor which the + // caller can use instead of calling + // out_persistent->AccessTensor. The caller does not own out_tensor + // and should not keep a copy of it. See comment above. + Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor); + + // User-supplied configuration of this operation. + const NodeDef& def() const { return *def_; } + + // Op registered for this op type. + const OpDef& op_def() const { return *op_def_; } + + // For inspecting the inputs to this operation. + int num_inputs() const { return input_types_.size(); } + DataType input_type(int i) const { return input_types_[i]; } + const DataTypeSlice& input_types() const { return input_types_; } + + // For inspecting the outputs expected from this operation. + int num_outputs() const { return output_types_.size(); } + DataType output_type(int i) const { return output_types_[i]; } + const DataTypeSlice& output_types() const { return output_types_; } + + // If expected_inputs == inputs() and expected_outputs == output_types(), + // returns OK, else returns INVALID_ARGUMENT with an error message. + // Recommended for Ops with dynamic signatures. + Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); + + // For recording configuration errors during construction. + void SetStatus(const Status& status) { status_->Update(status); } + const Status& status() const { return *status_; } + + // Look up the attr with name attr_name and set *value to its value. 
If no + // attr with attr_name is found in def(), or the attr does not have + // a matching type, a non-ok status will be returned. + template + Status GetAttr(const string& attr_name, T* value) const { + return GetNodeAttr(def(), attr_name, value); + } + + // May be used, e.g., to get GPU handles, etc. + // TODO(tucker): Add example usage. + DeviceBase* device() const { return device_; } + + // Return the device type. + const DeviceType& device_type() const { return device_type_; } + + // If not nullptr, the kernel can instantiate functions defined in + // the library. E.g., + // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...). + FunctionLibraryRuntime* function_library() const { return flib_; } + + private: + const DeviceType device_type_; + DeviceBase* const device_; + Allocator* allocator_; + const NodeDef* def_; + const OpDef* op_def_; + FunctionLibraryRuntime* flib_; + DataTypeSlice input_types_; + DataTypeSlice output_types_; + Status* status_; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction); +}; + +// TODO(mrry): Consider converting to a random_access_iterator, and upgrading +// tensorflow::gtl::iterator_range to make the below container classes +// unnecessary. +template +class OpArgIterator { + public: + typedef OpArgIterator ME; + OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} + bool operator==(const ME& rhs) { + DCHECK(list_ == rhs.list_); + return i_ == rhs.i_; + } + bool operator!=(const ME& rhs) { + DCHECK(list_ == rhs.list_); + return i_ != rhs.i_; + } + void operator++() { ++i_; } + ElementType& operator*() { return (*list_)[i_]; } + + private: + const ListType* const list_; + int i_; +}; + +// Utility class for representing a list of immutable input tensors +// that are passed to the op as a single named argument. +class OpInputList { + public: + typedef OpArgIterator Iterator; + OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpInputList(const OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpInputList& operator=(const OpInputList& other) = default; + const Tensor& operator[](int i) const; + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + const OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Utility class for representing a list of mutable ("ref") input tensors +// that are passed to the op as a single named argument. +class OpMutableInputList { + public: + typedef OpArgIterator Iterator; + OpMutableInputList(OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpMutableInputList& operator=(const OpMutableInputList& other) = default; + Tensor at(int i, bool lock_held); + mutex* ref_mutex(int i); + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Utility class for representing a list of output tensors that are +// grouped as a single named output. 
+class OpOutputList { + public: + typedef OpArgIterator Iterator; + OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpOutputList(OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpOutputList& operator=(const OpOutputList& other) = default; + Tensor* operator[](int i); + bool required(int i) const; + Status allocate(int i, const TensorShape& shape, Tensor** output); + void set(int i, const Tensor& tensor); + void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Holds a tensor or tensor reference. For tensor references, we need +// a mutex to prevent concurrent access to the tensor. +struct TensorValue { + TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {} + TensorValue(Tensor* t) // NOLINT(runtime/explicit) + : mutex_if_ref(nullptr), + tensor(t) {} + TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {} + Tensor* operator->() const { return tensor; } + bool is_ref() const { return mutex_if_ref != nullptr; } + + mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref + Tensor* tensor; +}; + +class OpKernelContext { + public: + // The first element of a WrappedAllocator is a "base" Allocator and + // the second element is that Allocator wrapped by a + // TrackingAllocator + typedef std::pair WrappedAllocator; + + // TODO(zhifengc): Do some cleanup of Params. + struct Params { + // The op kernel being computed. + OpKernel* op_kernel = nullptr; + + // The device on which the kernel is running. + DeviceBase* device = nullptr; + + bool track_allocations = false; + std::function output_alloc_attr = nullptr; + + // Shared resources accessible by this op kernel invocation. + ResourceMgr* resource_manager = nullptr; + + // Per-step resources accessible by this op kernel invocation. + ResourceMgr* step_resource_manager = nullptr; + + // Mechanism used by this op kernel invocation to communicate with + // computations running on other devices. + Rendezvous* rendezvous = nullptr; + + // Mechanism used by this op kernel invocation to register a callback + // for its cancellation. + CancellationManager* cancellation_manager = nullptr; + + // Inputs to this op kernel. + const gtl::InlinedVector* inputs = nullptr; + bool is_input_dead = false; + + const gtl::InlinedVector* input_alloc_attrs = + nullptr; + + // Device contexts. + const gtl::InlinedVector* input_device_contexts = + nullptr; + DeviceContext* op_device_context = nullptr; + + // Control-flow op supports. + FrameAndIter frame_iter; + + // Function call supports. + FunctionCallFrame* call_frame = nullptr; + FunctionLibraryRuntime* function_library = nullptr; + + // TensorSliceReaderCache support. + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; + }; + explicit OpKernelContext(const Params& params); + ~OpKernelContext(); + + Env* env() const { return params_.device->env(); } + + // Input/output signature. + + int num_inputs() const { return params_.inputs->size(); } + DataType input_dtype(int index) const; + int num_outputs() const { return outputs_.size(); } + DataType expected_output_dtype(int index) const; + + // Input + + // Returns an immutable input tensor. May only be used for non-Ref + // inputs. For Ref inputs use mutable_input below. 
+ // REQUIRES: !IsRefType(input_dtype(index)) + // TODO(mrry): Convert this to return Status. + const Tensor& input(int index) const; + + // Returns the named immutable input tensor in "tensor", as defined + // in the OpDef. May only be used for non-Ref inputs. For Ref inputs + // use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + // REQUIRES: the named input must not be a list. + Status input(const string& name, const Tensor** tensor) const; + + // Returns the named list-valued immutable input in "list", as + // defined in the OpDef. If the named output is not list-valued, + // returns a one-element list. May only be used for non-Ref + // inputs. For Ref inputs use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + Status input_list(const string& name, OpInputList* list) const; + + // For mutable inputs, use the following together to make sure there + // is no concurrent access to mutable_input(), e.g.: + // { + // Tensor& t = context->mutable_input(index); + // mutex_lock lock(*context->input_ref_mutex(index)); + // // modify the values in t + // } + // REQUIRES: IsRefType(input_dtype(index)) + // TODO(mrry): Convert this to return Status. + mutex* input_ref_mutex(int index); + Status input_ref_mutex(const string& name, mutex** out_mutex); + + // Returns a mutable input tensor. Must be used to access Ref + // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may + // modify the values stored in the Tensor buffer, and modifications + // will be visible to other Ops reading the same ref tensor. If + // !lock_held the input mutex will be acquired before returning the + // Tensor. + // TODO(mrry): + // Convert this to return Status. + Tensor mutable_input(int index, bool lock_held); + + // Returns the named mutable input tensor in "tensor", as defined in + // the OpDef. Must be used to access Ref inputs. The values stored + // in the Tensor buffer may be modified, and modifications will be + // visible to other Ops reading the same ref tensor. If !lock_held + // the input mutex will be acquired before returning the Tensor. + // REQUIRES: the named input must not be a list. + // REQUIRES: the named input must be a ref tensor. + Status mutable_input(const string& name, Tensor* tensor, bool lock_held); + + // Returns the named list-valued mutable input in "list", as defined + // in the OpDef. If the named intput is not list-valued, returns a + // one-element list. Must be used to access Ref inputs. The values + // stored in the Tensor buffer may be modified, and modifications + // will be visible to other Ops reading the same ref tensor. + // REQUIRES: the named input must be a ref tensor. + Status mutable_input_list(const string& name, OpMutableInputList* list); + + // Replace the corresponding Ref Input to use the storage buffer + // used by tensor. If !lock_held the input mutex will be acquired + // before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(index)). + void replace_ref_input(int index, const Tensor& tensor, bool lock_held); + + // Replace the corresponding named Ref Input to use the storage + // buffer used by tensor. If !lock_held the input mutex will be + // acquired before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(index)). + Status replace_ref_input(const string& name, const Tensor& tensor, + bool lock_held); + + // Set the output Ref Tensor at output_index to be an alias of the + // input Ref Tensor at input_index. + // REQUIRES: IsRefType(input_dtype(input_index)). 
+ // REQUIRES: IsRefType(output_dtype(output_index)). + void forward_ref_input_to_ref_output(int input_index, int output_index); + + // Deletes the Tensor object used as the Ref Input at + // input_index. This is not usually necessary and should be used + // with caution. If !lock_held the input mutex will be acquired + // before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(input_index)). + void delete_ref_input(int input_index, bool lock_held); + + // Return true if there is input at the given index. An operator has no + // input at index if its tensor is null. This is primarily used by the + // merge operator. + // TODO(mrry): Convert this to return Status. + bool has_input(int index) const; + + // Returns true if all inputs are the same shape, otherwise sets the + // status to a non-OK value and returns false. + // Usage: if (!context->ValidateInputsAreSameShape(this)) return; + bool ValidateInputsAreSameShape(OpKernel* op); + + // Output + + // Returns the named list-valued output in "list", as defined in the OpDef. + // If the named output is not list-valued, returns a one-element list. + Status output_list(const string& name, OpOutputList* list); + + // If output_required(index) returns true, the OpKernel's Compute() method + // should call allocate_output(index, ...), set_output(index, ...), + // set_output_ref(index, ...), or set the status to a non-ok value. + // If it returns false, it may output, but is not required to do so. + // TODO(mrry): Convert this to return Status, and implement a string + // name version. + bool output_required(int index) const { + return true; // TODO(josh11b): implement + } + + // Allocation of tensors during kernel execution inside the Compute + // method: + // + // There are three methods to allocate Tensors when an Op kernel + // executes. + // + // 1) allocate_persistent. This is only needed for Tensors that will + // be stored by the Op between invocations, and it *must* be used + // for those Tensors. The call returns a PersistentTensor, and that + // is the only object the Op is allowed to hold on to between + // invocations. When the Tensor is needed in a subsequent + // invocation, it can be retrieved from the PersistentTensor using + // the AccessTensor method. This ensures that the system is made + // aware of any use of the tensor's allocated memory, which is + // needed for correctness on asynchronous devices such as GPUs. + // + // 2) allocate_output. This should be used to allocate any tensor + // that is going to be used as an output from the Op at the end of + // the current execution. The caller indicates which output the + // Tensor will be assigned to, and the call returns the + // newly-allocated Tensor. The Tensor can subsequently be assigned + // to during kernel execution, and will be used as the designated + // output when the kernel execution completes. + // + // 3) allocate_temp. This should be used to allocate any scratch + // storage that is needed while the kernel is executing, and will + // not be retained by the Op. + // + // In some cases a Tensor needs to be used as an output even though + // it was previously allocated elsewhere. The Tensor may have been + // passed as an input, or stored in a PersistentTensor during a + // previous kernel execution, or allocated earlier in the kernel + // execution at a time when it was not known which output it would + // be assigned to. 
In this case the kernel can use set_output or + // set_output_ref to indicate that the tensor should be used as the + // designated output. It is legal to use any previously-allocated + // Tensor as an argument to set_output or set_output_ref, including + // Tensors allocated via allocate_temp. There may be a performance + // penalty to using a Tensor that was not allocated using + // allocate_output. This is because allocate_output uses the + // AllocatorAttributes stored in output_alloc_attr for the + // designated output. In some cases, using the wrong attributes may + // cause an extra copy of the Tensor's buffer. + + // Allocates output for the specified output index with shape. + // OpKernelContext retains ownership of the returned pointer. See + // comment above. + // + // If memory allocation fails, returns an error status. + // + // REQUIRES: !IsRefType(expected_output_dtype(index)) + Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor) TF_MUST_USE_RESULT; + Status allocate_output(const string& name, const TensorShape& shape, + Tensor** tensor) TF_MUST_USE_RESULT; + // The following methods use the supplied attributes instead of + // those in output_alloc_attr. The caller is responsible for + // ensuring that the attributes are "compatible" with the + // output_alloc_attr, e.g. the tensor is allocated on the correct + // device. See comment above. + Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, + AllocatorAttributes attr) TF_MUST_USE_RESULT; + Status allocate_output(const string& name, const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) TF_MUST_USE_RESULT; + + // Allocates a temporary Tensor of the specified type and + // shape. Devices such as GPUs that enqueue Ops for lazy execution + // may retain references to the temporary tensors after the Op's + // Compute method has run. See comment above. + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, AllocatorAttributes attr); + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp) { + return allocate_temp(type, shape, out_temp, AllocatorAttributes()); + } + + // Allocates a Tensor of the specified type and shape which the Op + // plans to maintain as persistent state. out_persistent holds the + // PersistentTensor which is the object the caller should store. For + // convenience, if out_tensor is non-null then it will be filled in + // with a Tensor* pointing to the newly-allocated tensor which the + // caller can use instead of calling + // out_persistent->AccessTensor. The caller does not own out_tensor + // and should not keep a copy of it. See comment above. + Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor, AllocatorAttributes attr); + Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor) { + return allocate_persistent(type, shape, out_persistent, out_tensor, + AllocatorAttributes()); + } + + // Copies a tensor (allocated by the caller) to the specified output + // index. REQUIRES: !IsRefType(expected_output_dtype(index)) + // REQUIRES: 'tensor' must have the same MemoryType as + // output_memory_types[index]. See comment above. + // TODO(mrry): Convert this to return Status. + void set_output(int index, const Tensor& tensor); + Status set_output(const string& name, const Tensor& tensor); + + // To output a reference. 
Caller retains ownership of mu and tensor_for_ref, + // and they must outlive all uses within the step. See comment above. + // REQUIRES: IsRefType(expected_output_dtype(index)) + // TODO(mrry): Convert this to return Status. + void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); + Status set_output_ref(const string& name, mutex* mu, Tensor* tensor_for_ref); + + // Returns nullptr if allocate_output() or set_output() have not been called. + // TODO(mrry): Convert this to return Status. + Tensor* mutable_output(int index); + Status mutable_output(const string& name, Tensor** tensor); + + // Transfers ownership of an output tensor to the caller. + // NOTE: For non-reference outputs, the caller takes responsibility + // for deletion. For reference outputs, the caller does NOT take + // responsibility for deletion. + // TODO(mrry): Convert this to return Status. + TensorValue release_output(int index); + Status release_output(const string& name, TensorValue* value); + + // Records device specific state about how the input tensors were + // computed. + // + // If using the templated function, the type must be a subclass + // of DeviceContext. + // + // Get the DeviceContext used for the index input. Returns nullptr + // if no DeviceContext was provided. + template + T* input_device_context(int index); + DeviceContext* input_device_context(int index); + + // Return the DeviceContext that should be used for this Op. + // + // If using the templated function, the type must be a subclass + // of DeviceContext. + // + // Returns nullptr if the device did not provide one. + template + T* op_device_context(); + DeviceContext* op_device_context() { + DeviceContext* ret = params_.op_device_context; + if (ret == nullptr) { + auto* dev_info = device()->tensorflow_gpu_device_info(); + if (dev_info) ret = dev_info->default_context; + } + return ret; + } + + AllocatorAttributes input_alloc_attr(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.input_alloc_attrs->size()); + return (*params_.input_alloc_attrs)[index]; + } + + AllocatorAttributes output_alloc_attr(int index) const { + return params_.output_alloc_attr(index); + } + + gtl::InlinedVector wrapped_allocators() const { + mutex_lock lock(mu_); + gtl::InlinedVector retrieved = wrapped_allocators_; + return retrieved; + } + + // Communication. + // + // An op kernel communicates with outside environment through + // Rendezvous Send() and Recv(). + Rendezvous* rendezvous() const { return params_.rendezvous; } + + // Function call support. + // + // If this kernel invocation is within a function execution, + // call_frame() returns the call frame for the function call. + FunctionCallFrame* call_frame() const { return params_.call_frame; } + + // If not nullptr, the kernel invoke functions defined in the + // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...). + FunctionLibraryRuntime* function_library() const { + return params_.function_library; + } + + // Shared resources accessible to this kernel. + ResourceMgr* resource_manager() const { return params_.resource_manager; } + + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const { + return params_.slice_reader_cache; + } + + // Execution. + // + // OpKernels can use these eigen devices to carry out their + // numerical computation. 
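+  //
+  // Illustrative sketch (assuming "in" and "out" are Eigen tensor maps of a
+  // device-templated kernel's input and output Tensors):
+  //   out.device(context->eigen_device<Device>()) = in * in;
+  // Assigning through device() evaluates the expression on the CPU thread
+  // pool or on the GPU stream associated with this kernel invocation.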
+ const Eigen::ThreadPoolDevice& eigen_cpu_device() const { + return *device()->eigen_cpu_device(); + } + const Eigen::GpuDevice& eigen_gpu_device() const { + return eigen_gpu_device_->device(); + } + template + const EigenDeviceType& eigen_device() const; + + // Error handling. + + // If expected_inputs == inputs() and expected_outputs == output_types(), + // returns OK, else returns INVALID_ARGUMENT with an error message. + // Recommended for Ops with dynamic signatures, where validation can only + // be performed at runtime. + Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); + + // An OpKernel should call SetStatus() if Compute() encounters an + // error. + void SetStatus(const Status& status) { status_.Update(status); } + const Status& status() const { return status_; } + + // Cancellation. + // + // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an + // example of how to use this API. + CancellationManager* cancellation_manager() const { + return params_.cancellation_manager; + } + + // Other accessors. + + // For control flow. + FrameAndIter frame_iter() const { return params_.frame_iter; } + bool is_input_dead() const { return params_.is_input_dead; } + bool* is_output_dead() { return &is_output_dead_; } + + // May be used, e.g., to get GPU handles, etc. + // TODO(tucker): Add example usage. + DeviceBase* device() const { return params_.device; } + + // Access to list of temporary tensors. + int num_temps(); + Tensor* temp(int index); + + // Access to information about whether each output was newly + // allocated or copied from an existing tensor + AllocationType output_allocation_type(int index) const { + return output_allocation_types_[index]; + } + + private: + Allocator* get_allocator(AllocatorAttributes attr) { + Allocator* allocator = params_.device->GetAllocator(attr); + if (params_.track_allocations) { + mutex_lock lock(mu_); + for (const auto& wrapped : wrapped_allocators_) { + if (wrapped.first == allocator) { + return wrapped.second; + } + } + TrackingAllocator* wrapped_allocator = new TrackingAllocator(allocator); + wrapped_allocators_.push_back( + std::make_pair(allocator, wrapped_allocator)); + return wrapped_allocator; + } else { + return allocator; + } + } + + // Per-step resource manager for use by white-listed internal ops. + friend class TemporaryVariableOp; + friend class DestroyTemporaryVariableOp; + ResourceMgr* step_resource_manager() const { + return params_.step_resource_manager; + } + + // Internal common method used when allocating tensor memory + Status allocate_tensor(DataType type, const TensorShape& shape, + Tensor* out_tensor, AllocatorAttributes attr); + + // This is called by PersistentTensor::AccessTensor whenever the + // wrapped tensor is retrieved, to ensure the runtime knows that the + // Tensor is being accessed within an Op. This is necessary for + // memory safety of devices like GPUs that queue Ops for + // asynchronous execution after the Compute() method completes. + friend class PersistentTensor; + void NotifyUseOfPersistentTensor(const Tensor& tensor); + + Status status_; + Params params_; // immutable after construction. 
+ const PerOpGpuDevice* eigen_gpu_device_; // owned, with a per-op + // wrapped allocator + mutable mutex mu_; // mutable so const accessors can acquire the lock + gtl::InlinedVector wrapped_allocators_ GUARDED_BY(mu_); + gtl::InlinedVector outputs_; + gtl::InlinedVector output_allocation_types_; + gtl::InlinedVector temp_tensors_; + bool is_output_dead_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext); +}; + +// Register your OpKernel by specifying the Op's name, the device the +// kernel runs on, any type attr constraints for this kernel, any +// host-memory args, and the class to instantiate. Examples: +// +// // A kernel that supports all types. +// REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp); +// +// // The following are equivalent ways of specifying that the kernel only +// // works if the "T" type attr is set to DT_FLOAT. +// REGISTER_KERNEL_BUILDER( +// Name("Sub").Device(DEVICE_CPU).TypeConstraint("T"), +// SubOp); +// // (You would then repeat this for every type supported by "Sub".) +// +// // This form allows you to specify a list of types as the constraint. +// REGISTER_KERNEL_BUILDER(Name("Sub") +// .Device(DEVICE_CPU) +// .TypeConstraint("T", {DT_FLOAT}), +// SubOp); +// +// // A kernel that expects one of the input tensors in host memory. +// REGISTER_KERNEL_BUILDER( +// Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp); +// +// See kernel_def_builder for details. + +// Instantiate an OpKernel that has been registered. Returns nullptr +// if no operation for that type of device / input signature combination +// (and a NOT_FOUND *status), or there is an error in construction (and +// an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership +// of the returned pointer. +// EXPECTED USAGE: unique_ptr op = CreateOpKernel(...); +// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +std::unique_ptr CreateOpKernel(DeviceType device_type, + DeviceBase* device, + Allocator* allocator, + const NodeDef& def, Status* status); +Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const NodeDef& def, OpKernel** kernel); + +// Returns into 'device_types' the subset of prioritized_types that this +// binary has registered for the given NodeDef. +// +// REQUIRES: * 'device_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +Status SupportedDeviceTypesForNode( + const std::vector& prioritized_types, const NodeDef& def, + DeviceTypeVector* device_types); + +// Returns into *{input,output}_memory_types the memory type of each +// {input,output} tensor. +// +// REQUIRES: * '*_memory_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef, + const OpDef& op_def, + const NameRangeMap& input_name_map, + const NameRangeMap& output_name_map, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types); + +Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + DeviceType device_type, const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types); + +// Call once after Op registration has completed. +Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry); + +// ----------------------------------------------------------------------------- +// OpKernel registration implementation follows, please ignore. 
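+
+// For reference, a minimal synchronous kernel of the kind registered by the
+// examples above might look like the following sketch ("Square"/"SquareOp"
+// are hypothetical names, not ops defined in this library):
+//
+//   class SquareOp : public OpKernel {
+//    public:
+//     explicit SquareOp(OpKernelConstruction* context) : OpKernel(context) {}
+//
+//     void Compute(OpKernelContext* context) override {
+//       const Tensor& input = context->input(0);
+//       Tensor* output = nullptr;
+//       OP_REQUIRES_OK(context,
+//                      context->allocate_output(0, input.shape(), &output));
+//       // ... compute *output from input ...
+//     }
+//   };
+//
+//   REGISTER_KERNEL_BUILDER(Name("Square").Device(DEVICE_CPU), SquareOp);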
+ +// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax. +namespace register_kernel { +typedef ::tensorflow::KernelDefBuilder Name; +} // namespace register_kernel + +#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ + REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) + +#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ + REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) + +#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ + static ::tensorflow::kernel_factory::OpKernelRegistrar \ + registrar__body__##ctr##__object( \ + ::tensorflow::register_kernel::kernel_builder.Build(), \ + +[](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); }) + +namespace kernel_factory { + +class OpKernelRegistrar { + public: + typedef OpKernel* (*Factory)(OpKernelConstruction*); + OpKernelRegistrar(const KernelDef* kernel_def, Factory factory); +}; + +} // namespace kernel_factory + +// ----------------------------------------------------------------------------- +// Template and inline method implementations, please ignore + +inline DataType OpKernelContext::input_dtype(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + const TensorValue& value((*params_.inputs)[index]); + if (value.is_ref()) { + return MakeRefType(value->dtype()); + } else { + return value->dtype(); + } +} + +inline DataType OpKernelContext::expected_output_dtype(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.op_kernel->output_types().size()); + return params_.op_kernel->output_type(index); +} + +inline const Tensor& OpKernelContext::input(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK(!(*params_.inputs)[index].is_ref()); + return *((*params_.inputs)[index].tensor); +} + +inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK((*params_.inputs)[index].is_ref()); + // return a copy of the Ref acquired while holding the mutex + if (lock_held) { + return *((*params_.inputs)[index].tensor); + } else { + mutex_lock l(*input_ref_mutex(index)); + return *((*params_.inputs)[index].tensor); + } +} + +inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, + bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK((*params_.inputs)[index].is_ref()); + // should only modify the tensor while holding the mutex + if (lock_held) { + *(*params_.inputs)[index].tensor = tensor; + } else { + mutex_lock l(*input_ref_mutex(index)); + *(*params_.inputs)[index].tensor = tensor; + } +} + +inline void OpKernelContext::forward_ref_input_to_ref_output(int input_index, + int output_index) { + DCHECK_GE(input_index, 0); + DCHECK_LT(input_index, params_.inputs->size()); + DCHECK((*params_.inputs)[input_index].is_ref()); + set_output_ref(output_index, (*params_.inputs)[input_index].mutex_if_ref, + (*params_.inputs)[input_index].tensor); +} + +inline void OpKernelContext::delete_ref_input(int index, bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK((*params_.inputs)[index].is_ref()); + // should only modify the tensor while holding the mutex + if (lock_held) { + delete (*params_.inputs)[index].tensor; + } else { + mutex_lock l(*input_ref_mutex(index)); + delete (*params_.inputs)[index].tensor; + } +} + +// no 
input if tensor == nullptr. +inline bool OpKernelContext::has_input(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + return (*params_.inputs)[index].tensor != nullptr; +} + +inline mutex* OpKernelContext::input_ref_mutex(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK((*params_.inputs)[index].is_ref()); + return (*params_.inputs)[index].mutex_if_ref; +} + +inline Status OpKernelContext::allocate_output(int index, + const TensorShape& shape, + Tensor** output) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + DCHECK(params_.output_alloc_attr); + AllocatorAttributes attr = params_.output_alloc_attr(index); + return allocate_output(index, shape, output, attr); +} + +inline Status OpKernelContext::allocate_tensor(DataType type, + const TensorShape& shape, + Tensor* out_tensor, + AllocatorAttributes attr) { + Allocator* a = get_allocator(attr); + Tensor new_tensor(a, type, shape); + + if (!new_tensor.IsInitialized() && shape.num_elements() > 0) { + return errors::ResourceExhausted("OOM when allocating tensor with shape", + shape.DebugString()); + } + *out_tensor = new_tensor; + return Status::OK(); +} + +inline Status OpKernelContext::allocate_output(int index, + const TensorShape& shape, + Tensor** output, + AllocatorAttributes attr) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + // Record the fact that this output tensor was allocated by the Op + DCHECK_LT(index, output_allocation_types_.size()); + output_allocation_types_[index] = AT_ALLOCATED; + const DataType type = params_.op_kernel->output_type(index); + DCHECK(!IsRefType(type)); + DCHECK(mutable_output(index) == nullptr); + Tensor* output_tensor = new Tensor(); + Status s = allocate_tensor(type, shape, output_tensor, attr); + if (s.ok()) { + outputs_[index] = TensorValue(output_tensor); + *output = outputs_[index].tensor; + } + return s; +} + +inline Status OpKernelContext::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp, + AllocatorAttributes attr) { + Status s = allocate_tensor(type, shape, out_temp, attr); + if (s.ok()) { + if (params_.device->SaveTemporaryTensors()) { + // keep a reference to the underlying memory around + temp_tensors_.push_back(new Tensor(*out_temp)); + } + } + return s; +} + +inline Status OpKernelContext::allocate_persistent( + DataType type, const TensorShape& shape, PersistentTensor* out_persistent, + Tensor** out_tensor, AllocatorAttributes attr) { + // TODO(misard) add specific memory tracking for persistent tensors + Tensor persistent; + Status s = allocate_tensor(type, shape, &persistent, attr); + if (s.ok()) { + *out_persistent = PersistentTensor(persistent); + // This call saves a reference to the newly-allocated tensor if we + // are saving temporary tensors + Tensor* allocated = out_persistent->AccessTensor(this); + if (out_tensor) { + *out_tensor = allocated; + } + } + return s; +} + +inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) { + if (t.IsInitialized() && params_.device->SaveTemporaryTensors()) { + // keep a reference to the underlying memory around + temp_tensors_.push_back(new Tensor(t)); + } +} + +inline int OpKernelContext::num_temps() { return temp_tensors_.size(); } + +inline Tensor* OpKernelContext::temp(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, temp_tensors_.size()); + return temp_tensors_[index]; +} + +inline void OpKernelContext::set_output(int index, const Tensor& tensor) { + DCHECK_GE(index, 0); + 
DCHECK_LT(index, outputs_.size()); + // Record the fact that this output tensor was set by the Op + DCHECK_LT(index, output_allocation_types_.size()); + output_allocation_types_[index] = AT_EXISTING; + DCHECK(!IsRefType(params_.op_kernel->output_type(index))); + DCHECK_EQ(mutable_output(index), nullptr); + outputs_[index] = TensorValue(new Tensor(tensor)); +} + +inline void OpKernelContext::set_output_ref(int index, mutex* mu, + Tensor* tensor_for_ref) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + // Record the fact that this output tensor was set by reference the Op + DCHECK_LT(index, output_allocation_types_.size()); + output_allocation_types_[index] = AT_REF; + DCHECK(IsRefType(params_.op_kernel->output_type(index))); + outputs_[index] = TensorValue(mu, tensor_for_ref); +} + +inline Tensor* OpKernelContext::mutable_output(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + return outputs_[index].tensor; +} + +inline TensorValue OpKernelContext::release_output(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + TensorValue value = outputs_[index]; + outputs_[index] = TensorValue(); + return value; +} + +template +T* OpKernelContext::op_device_context() { + static_assert(std::is_base_of::value, + "T is not a subclass of DeviceContext"); + return static_cast(op_device_context()); +} + +template +T* OpKernelContext::input_device_context(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.input_device_contexts->size()); + static_assert(std::is_base_of::value, + "T is not a subclass of DeviceContext"); + return static_cast((*params_.input_device_contexts)[index]); +} + +inline DeviceContext* OpKernelContext::input_device_context(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.input_device_contexts->size()); + return (*params_.input_device_contexts)[index]; +} + +inline const Tensor& OpInputList::operator[](int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->input(start_ + i); +} + +inline mutex* OpMutableInputList::ref_mutex(int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->input_ref_mutex(start_ + i); +} + +inline Tensor OpMutableInputList::at(int i, bool lock_held) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->mutable_input(start_ + i, lock_held); +} + +inline Tensor* OpOutputList::operator[](int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->mutable_output(start_ + i); +} + +inline bool OpOutputList::required(int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->output_required(start_ + i); +} + +inline Status OpOutputList::allocate(int i, const TensorShape& shape, + Tensor** output) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->allocate_output(start_ + i, shape, output); +} + +inline void OpOutputList::set(int i, const Tensor& tensor) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output(start_ + i, tensor); +} + +inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output_ref(i, mu, tensor_for_ref); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc new file mode 100644 index 0000000000..9400ef24f8 --- /dev/null +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -0,0 +1,803 @@ +#include 
"tensorflow/core/framework/op_kernel.h" + +#include +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include + +class DummyKernel : public tensorflow::OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +// Test that registration works outside a namespace. +REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8"); +REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU), + DummyKernel); + +namespace foo { +bool match_signature_ = false; + +// Test that registration works inside a different namespace. +class TestOp2 : public ::tensorflow::OpKernel { + public: + explicit TestOp2(::tensorflow::OpKernelConstruction* context) + : OpKernel(context) { + ::tensorflow::Status status = context->MatchSignature( + {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32}); + match_signature_ = status.ok(); + context->SetStatus(status); + } + void Compute(::tensorflow::OpKernelContext* context) override {} +}; + +REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("Test2") + .Device(::tensorflow::DEVICE_GPU) + .HostMemory("i") + .HostMemory("o"), + TestOp2); +} // namespace foo + +namespace tensorflow { + +// Two operations with the same name but different devices. 
+REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type"); + +class TestOp3Cpu : public tensorflow::OpKernel { + public: + explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER( + Name("Test3").Device(DEVICE_CPU).TypeConstraint("T"), TestOp3Cpu); + +namespace { + +class TestOp3Gpu : public tensorflow::OpKernel { + public: + explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER( + Name("Test3").Device(DEVICE_GPU).TypeConstraint("T"), TestOp3Cpu); + +// An Op registered for both +REGISTER_OP("Test4").Input("i: float").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel); + +static std::vector DeviceTypes() { + return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}; +} + +class OpKernelTest : public ::testing::Test { + public: + OpKernelTest() : device_(Env::Default()) {} + + protected: + NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) { + NodeDefBuilder builder(op_type + "-op", op_type); + for (DataType dt : inputs) { + builder.Input(FakeInput(dt)); + } + NodeDef node_def; + TF_CHECK_OK(builder.Finalize(&node_def)); + return node_def; + } + + void ExpectEqual(const string& what, const DataTypeVector& expected, + const DataTypeVector& observed) { + EXPECT_EQ(expected.size(), observed.size()) << what; + const int size = std::min(expected.size(), observed.size()); + for (int i = 0; i < size; ++i) { + bool match = TypesCompatible(expected[i], observed[i]); + EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i] + << ", observed: " << observed[i]; + } + } + + void ExpectSuccess(const string& op_type, DeviceType device_type, + const DataTypeVector& inputs, + const DataTypeVector& outputs) { + Status status; + std::unique_ptr op( + CreateOpKernel(device_type, &device_, cpu_allocator(), + CreateNodeDef(op_type, inputs), &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + ExpectEqual("inputs", op->input_types(), inputs); + ExpectEqual("outputs", op->output_types(), outputs); + } + } + + void ExpectFailure(const string& ascii_node_def, DeviceType device_type, + error::Code code) { + NodeDef node_def; + protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def); + Status status; + std::unique_ptr op(CreateOpKernel( + device_type, &device_, cpu_allocator(), node_def, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << status.error_message(); + EXPECT_EQ(code, status.code()); + } + } + + private: + DeviceBase device_; +}; + +TEST_F(OpKernelTest, SuccessCpu) { + ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8}); + ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8}); +} + +TEST_F(OpKernelTest, SuccessGpu) { + foo::match_signature_ = false; + ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32}); + EXPECT_TRUE(foo::match_signature_); +} + +TEST_F(OpKernelTest, SuccessBothCpuAndGpu) { + ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {}); + ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {}); +} + +TEST_F(OpKernelTest, CpuTypeRegistered) { + NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + DeviceTypeVector devs; + 
ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); +} + +TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { + { + // Try a node def of an op that is registered for a specific type + // only on CPU. + NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); + } + { + // Try a node def of an op that is registered for a specific type + // only on GPU. + NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + } + { + // Try a node def of an op that is only registered for other types. + NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(0, devs.size()); + } + + { + // Try a node def of an op that is registered for both. + NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(2, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]); + } +} + +TEST_F(OpKernelTest, NotFound) { + const auto not_found = error::NOT_FOUND; + // Something with that op type name exists, but only with a + // different DeviceType. + ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(), + DEVICE_GPU, not_found); + ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(), + DEVICE_GPU, not_found); + ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(), + DEVICE_CPU, not_found); + + // No kernel with that signature registered. + ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(), + DEVICE_GPU, not_found); + + // Nothing with that op type name exists. 
+ ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_CPU, not_found); + ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_GPU, not_found); +} + +TEST_F(OpKernelTest, TooFewInputs) { + const auto invalid = error::INVALID_ARGUMENT; + NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + node_def.clear_input(); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); + node_def.add_input("a"); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); +} + +TEST_F(OpKernelTest, TooManyInputs) { + const auto invalid = error::INVALID_ARGUMENT; + NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + node_def.add_input("c"); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); +} + +TEST_F(OpKernelTest, MatchSignatureFailes) { + const auto invalid = error::INVALID_ARGUMENT; + foo::match_signature_ = true; + ExpectFailure(CreateNodeDef("Test2", {DT_FLOAT}).DebugString(), DEVICE_GPU, + invalid); + EXPECT_FALSE(foo::match_signature_); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool SaveTemporaryTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST_F(OpKernelTest, SaveTempFalse) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.device = new DummyDevice(env, false); + Status status; + std::unique_ptr op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + OpKernelContext* ctx = new OpKernelContext(params); + + Tensor t; + EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); + + EXPECT_EQ(0, ctx->num_temps()); + + delete ctx; + delete params.device; +} + +TEST_F(OpKernelTest, SaveTempTrue) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.device = new DummyDevice(env, true); + Status status; + std::unique_ptr op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + OpKernelContext* ctx = new OpKernelContext(params); + + Tensor t; + EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); + + EXPECT_EQ(1, ctx->num_temps()); + + delete ctx; + delete params.device; +} + +class OpKernelBuilderTest : public ::testing::Test { + protected: + // Each attr is described by a "name|type|value". 
+ NodeDef CreateNodeDef(const string& op_type, + const std::vector& attrs) { + NodeDef node_def; + node_def.set_name(op_type + "-op"); + node_def.set_op(op_type); + for (const string& attr_desc : attrs) { + std::vector parts = str_util::Split(attr_desc, '|'); + CHECK_EQ(parts.size(), 3); + AttrValue attr_value; + CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc; + node_def.mutable_attr()->insert( + AttrValueMap::value_type(parts[0], attr_value)); + } + return node_def; + } + + std::unique_ptr ExpectSuccess(const string& op_type, + DeviceType device_type, + const std::vector& attrs, + DataTypeSlice input_types = {}) { + Status status; + NodeDef def = CreateNodeDef(op_type, attrs); + for (size_t i = 0; i < input_types.size(); ++i) { + def.add_input("a:0"); + } + + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel() + std::unique_ptr op( + CreateOpKernel(device_type, &device, cpu_allocator(), def, &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + EXPECT_EQ(input_types.size(), op->num_inputs()); + EXPECT_EQ(0, op->num_outputs()); + } + + // Test SupportedDeviceTypesForNode() + DeviceTypeVector devices; + EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + bool found = false; + for (DeviceType dt : devices) { + if (dt == device_type) { + found = true; + } + } + EXPECT_TRUE(found) << "Missing " << device_type << " from " + << devices.size() << " devices."; + + // In case the caller wants to use the OpKernel + return op; + } + + void ExpectFailure(const string& op_type, DeviceType device_type, + const std::vector& attrs, error::Code code) { + Status status; + const NodeDef def = CreateNodeDef(op_type, attrs); + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel(). + std::unique_ptr op( + CreateOpKernel(device_type, &device, cpu_allocator(), def, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << status.error_message(); + EXPECT_EQ(code, status.code()); + + // Test SupportedDeviceTypesForNode(). 
+ DeviceTypeVector devices; + if (errors::IsNotFound(status)) { + EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + for (DeviceType dt : devices) { + EXPECT_NE(dt, device_type); + } + } else { + Status status2 = + SupportedDeviceTypesForNode(DeviceTypes(), def, &devices); + EXPECT_EQ(status.code(), status2.code()); + } + } + } +}; + +REGISTER_OP("BuildCPU"); +REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderCPU) { + ExpectSuccess("BuildCPU", DEVICE_CPU, {}); + ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND); +} + +REGISTER_OP("BuildGPU"); +REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderGPU) { + ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND); + ExpectSuccess("BuildGPU", DEVICE_GPU, {}); +} + +REGISTER_OP("BuildBoth"); +REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderBoth) { + ExpectSuccess("BuildBoth", DEVICE_CPU, {}); + ExpectSuccess("BuildBoth", DEVICE_GPU, {}); +} + +REGISTER_OP("BuildTypeAttr").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("BuildTypeAttr") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderTypeAttr) { + ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"}); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"}, + error::NOT_FOUND); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)"); +REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"}); + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"}); + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[DT_BOOL, DT_BOOL]"}); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"}, + error::NOT_FOUND); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("DuplicateKernel"); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), + DummyKernel); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), + DummyKernel); + +TEST_F(OpKernelBuilderTest, DuplicateKernel) { + const NodeDef ndef = CreateNodeDef("DuplicateKernel", {}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Multiple OpKernel registrations match NodeDef")); + + ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT); +} + +REGISTER_OP("DuplicateKernelForT").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DummyKernel); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { + const NodeDef ndef = + CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"}); + DeviceTypeVector devs; + 
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Multiple OpKernel registrations match NodeDef")); + + ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"}, + error::INVALID_ARGUMENT); + ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"}, + error::NOT_FOUND); +} + +REGISTER_OP("BadConstraint").Attr("dtype: type"); +REGISTER_KERNEL_BUILDER(Name("BadConstraint") + .Device(DEVICE_CPU) + // Mistake: "T" should be "dtype". + .TypeConstraint("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BadConstraint) { + const NodeDef ndef = CreateNodeDef("BadConstraint", {}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("OpKernel 'BadConstraint' has constraint on attr " + "'T' not in NodeDef")); + + ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"}, + error::INVALID_ARGUMENT); +} + +class GetAttrKernel : public ::tensorflow::OpKernel { + public: + explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) { + string attr_name; + OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name)); + + status.emplace_back("s", context->GetAttr(attr_name, &s)); + status.emplace_back("s_list", context->GetAttr(attr_name, &s_list)); + status.emplace_back("i", context->GetAttr(attr_name, &i)); + status.emplace_back("i_list", context->GetAttr(attr_name, &i_list)); + status.emplace_back("i32", context->GetAttr(attr_name, &i32)); + status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list)); + status.emplace_back("f", context->GetAttr(attr_name, &f)); + status.emplace_back("f_list", context->GetAttr(attr_name, &f_list)); + status.emplace_back("b", context->GetAttr(attr_name, &b)); + status.emplace_back("b_list", context->GetAttr(attr_name, &b_list)); + status.emplace_back("type", context->GetAttr(attr_name, &type)); + status.emplace_back("type_list", context->GetAttr(attr_name, &type_list)); + status.emplace_back("type_vector", + context->GetAttr(attr_name, &type_vector)); + status.emplace_back("shape_proto", + context->GetAttr(attr_name, &shape_proto)); + status.emplace_back("shape_proto_list", + context->GetAttr(attr_name, &shape_proto_list)); + status.emplace_back("shape", context->GetAttr(attr_name, &shape)); + status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list)); + } + void Compute(::tensorflow::OpKernelContext* context) override {} + + void ExpectOk(std::initializer_list keys) { + for (const auto& key_status : status) { + // Only the status for keys in "keys" should be ok(). 
+ bool in_keys = false; + for (const string& key : keys) { + if (key_status.first == key) { + in_keys = true; + } + } + EXPECT_EQ(in_keys, key_status.second.ok()) + << "key_status: " << key_status.first << ", " << key_status.second; + } + } + + string s; + std::vector s_list; + int64 i; + std::vector i_list; + int32 i32; + std::vector i32_list; + float f; + std::vector f_list; + bool b; + std::vector b_list; + DataType type; + std::vector type_list; + DataTypeVector type_vector; + TensorShapeProto shape_proto; + std::vector shape_proto_list; + TensorShape shape; + std::vector shape_list; + std::vector> status; +}; + +class GetAttrTest : public OpKernelBuilderTest {}; + +REGISTER_OP("GetAttrStringList") + .Attr("attr_name: string") + .Attr("a: list(string)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU), + GetAttrKernel); + +TEST_F(GetAttrTest, StringList) { + std::unique_ptr op_kernel = + ExpectSuccess("GetAttrStringList", DEVICE_CPU, + {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"s_list"}); + EXPECT_EQ(std::vector({"foo", "bar"}), get_attr_kernel->s_list); + + op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU, + {"attr_name|string|'b'", "a|list(string)|['baz']"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({}); + EXPECT_TRUE(get_attr_kernel->s_list.empty()); +} + +REGISTER_OP("GetAttrInt") + .Attr("attr_name: string") + .Attr("a: int") + .Attr("b: list(int)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Int) { + std::unique_ptr op_kernel = ExpectSuccess( + "GetAttrInt", DEVICE_CPU, + {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"i", "i32"}); + EXPECT_EQ(35, get_attr_kernel->i); + EXPECT_EQ(35, get_attr_kernel->i32); + + op_kernel = ExpectSuccess( + "GetAttrInt", DEVICE_CPU, + {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"i_list", "i32_list"}); + EXPECT_EQ(std::vector({-1, 2, -4}), get_attr_kernel->i_list); + EXPECT_EQ(std::vector({-1, 2, -4}), get_attr_kernel->i32_list); + + // 8589934592 == 2^33, too big to fit in an int32 + op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU, + {"attr_name|string|'a'", "a|int|8589934592", + "b|list(int)|[-8589934592]"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"i"}); // no i32 + EXPECT_EQ(8589934592ll, get_attr_kernel->i); + for (const auto& key_status : get_attr_kernel->status) { + if (key_status.first == "i32") { + EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code()); + EXPECT_EQ("Attr a has value 8589934592 out of range for an int32", + key_status.second.error_message()); + } + } + + op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU, + {"attr_name|string|'b'", "a|int|8589934592", + "b|list(int)|[-8589934592]"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"i_list"}); // no i32_list + EXPECT_EQ(std::vector({-8589934592ll}), get_attr_kernel->i_list); + for (const auto& key_status : get_attr_kernel->status) { + if (key_status.first == "i32_list") { + EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code()); + EXPECT_EQ("Attr b has value -8589934592 out of range for an int32", + key_status.second.error_message()); + } + } +} + 
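+
+// For reference, a production kernel consumes attrs far more simply than
+// GetAttrKernel's exhaustive probing above: it calls GetAttr in its
+// constructor and fails construction via OP_REQUIRES_OK on error. A minimal
+// sketch (hypothetical kernel with attrs "N: int" and "T: type"):
+//
+//   explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {
+//     OP_REQUIRES_OK(context, context->GetAttr("N", &num_elements_));
+//     OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
+//   }
+//   ...
+//   int64 num_elements_;   // from attr "N"
+//   DataType dtype_;       // from attr "T"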
+REGISTER_OP("GetAttrShape") + .Attr("attr_name: string") + .Attr("a: shape") + .Attr("b: list(shape)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Shape) { + std::unique_ptr op_kernel = ExpectSuccess( + "GetAttrShape", DEVICE_CPU, + {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }", + "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"shape", "shape_proto"}); + EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }"); + EXPECT_EQ("[3]", get_attr_kernel->shape.ShortDebugString()); + + op_kernel = ExpectSuccess( + "GetAttrShape", DEVICE_CPU, + {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }", + "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"}); + ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size()); + EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(), + "dim { size: 2 }"); + EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(), + "dim { size: 4 }"); + ASSERT_EQ(2, get_attr_kernel->shape_list.size()); + EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].ShortDebugString()); + EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].ShortDebugString()); +} + +REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type"); +REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Type) { + std::unique_ptr op_kernel = ExpectSuccess( + "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"type"}); + EXPECT_EQ(DT_FLOAT, get_attr_kernel->type); +} + +REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU), + GetAttrKernel); + +TEST_F(GetAttrTest, TypeList) { + std::unique_ptr op_kernel = ExpectSuccess( + "GetAttrTypeList", DEVICE_CPU, + {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + + get_attr_kernel->ExpectOk({"type_list", "type_vector"}); + ASSERT_EQ(2, get_attr_kernel->type_list.size()); + EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]); + EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]); + ASSERT_EQ(2, get_attr_kernel->type_vector.size()); + EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]); + EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]); +} + +REGISTER_OP("HostMemoryTest") + .Input("a: float") + .Input("b: T") + .Input("c: N * string") + .Output("o: N * T") + .Attr("T: type") + .Attr("N: int"); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") + .Device(DEVICE_GPU) + .HostMemory("a") + .HostMemory("c") + .HostMemory("o"), + DummyKernel); + +TEST(MemoryTypesForNode, Simple) { + NodeDef node_def; + ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest") + .Input(FakeInput()) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(3)) + .Finalize(&node_def)); + MemoryTypeVector input, output; + + EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, + &input, &output)); + EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input); + EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output); + + EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, 
node_def, + &input, &output)); + EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY}), + input); + EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output); +} + +class BaseKernel : public ::tensorflow::OpKernel { + public: + explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(::tensorflow::OpKernelContext* context) override {} + virtual int Which() const = 0; +}; + +template +class LabeledKernel : public BaseKernel { + public: + using BaseKernel::BaseKernel; + int Which() const override { return WHICH; } +}; + +class LabelTest : public OpKernelBuilderTest {}; + +REGISTER_OP("LabeledKernel"); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU), + LabeledKernel<0>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"), + LabeledKernel<1>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"), + LabeledKernel<2>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"), + LabeledKernel<3>); + +TEST_F(LabelTest, Default) { + std::unique_ptr op_kernel = + ExpectSuccess("LabeledKernel", DEVICE_CPU, {}); + auto* get_labeled_kernel = static_cast(op_kernel.get()); + EXPECT_EQ(0, get_labeled_kernel->Which()); +} + +TEST_F(LabelTest, Specified) { + std::unique_ptr op_kernel = + ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"}); + auto* get_labeled_kernel = static_cast(op_kernel.get()); + EXPECT_EQ(1, get_labeled_kernel->Which()); +} + +TEST_F(LabelTest, Duplicate) { + ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"}, + error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc new file mode 100644 index 0000000000..a39bebd854 --- /dev/null +++ b/tensorflow/core/framework/op_segment.cc @@ -0,0 +1,86 @@ +#include "tensorflow/core/framework/op_segment.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +OpSegment::Item::~Item() { + for (auto kv : name_kernel) delete kv.second; +} + +OpSegment::OpSegment() {} + +OpSegment::~OpSegment() { + for (auto kv : sessions_) delete kv.second; +} + +Status OpSegment::FindOrCreate(const string& session_handle, + const string& node_name, OpKernel** kernel, + CreateKernelFn create_fn) { + { + mutex_lock l(mu_); + auto item = gtl::FindPtrOrNull(sessions_, session_handle); + if (item == nullptr) { + return errors::NotFound("Session ", session_handle, " is not found."); + } + *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name); + if (*kernel != nullptr) { + return Status::OK(); + } + } + Status s = create_fn(kernel); + if (!s.ok()) { + LOG(ERROR) << "Create kernel failed: " << s; + return s; + } + { + mutex_lock l(mu_); + auto item = gtl::FindPtrOrNull(sessions_, session_handle); + if (item == nullptr) { + return errors::NotFound("Session ", session_handle, " is not found."); + } + OpKernel** p_kernel = &(item->name_kernel[node_name]); + if (*p_kernel == nullptr) { + *p_kernel = *kernel; // Inserts 'kernel' in the map. 
+ } else { + delete *kernel; + *kernel = *p_kernel; + } + } + return Status::OK(); +} + +void OpSegment::AddHold(const string& session_handle) { + mutex_lock l(mu_); + Item** item = &sessions_[session_handle]; + if (*item == nullptr) { + *item = new Item; // num_holds == 1 + } else { + ++((*item)->num_holds); + } +} + +void OpSegment::RemoveHold(const string& session_handle) { + Item* item = nullptr; + { + mutex_lock l(mu_); + auto siter = sessions_.find(session_handle); + if (siter == sessions_.end()) { + VLOG(1) << "Session " << session_handle << " is not found."; + return; + } + item = siter->second; + if (--(item->num_holds) > 0) { + return; + } else { + sessions_.erase(siter); + } + } + delete item; +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h new file mode 100644 index 0000000000..55249d2a38 --- /dev/null +++ b/tensorflow/core/framework/op_segment.h @@ -0,0 +1,67 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ +#define TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// OpSegment keeps track of OpKernels registered for sessions running +// on a device. +// +// The implementation maintains a two-level map. The 1st level maps +// session handle to the map of registered OpKernels. The 2nd level +// map maps node names to instantiated OpKernel objects. +// +// Each 2-nd level map is reference-counted and the caller can call +// AddHold to obtain a reference on all kernels of a session and +// ensure these kernels are alive until a corresponding RemoveHold is +// called on the same session. +class OpSegment { + public: + OpSegment(); + ~OpSegment(); + + // A hold can be placed on a session, preventing all its kernels + // from being deleted. + void AddHold(const string& session_handle); + void RemoveHold(const string& session_handle); + + // If the kernel for "node_name" has been created in the + // "session_handle", returns the existing op kernel in "*kernel". + // Otherwise, creates the kernel by calling create_fn(), cache it, + // and returns it in "*kernel". If create_fn() fails, returns the + // error. + // + // OpSegment keeps the ownership of the returned "*kernel". + typedef std::function CreateKernelFn; + Status FindOrCreate(const string& session_handle, const string& node_name, + OpKernel** kernel, CreateKernelFn create_fn); + + private: + // op name -> OpKernel + typedef std::unordered_map KernelMap; + struct Item { + int num_holds = 1; // Num of holds put on the session. + KernelMap name_kernel; // op name -> kernel. + ~Item(); + }; + + // session handle -> item. 
+ // Session handles are produced by strings::FpToString() + typedef std::unordered_map SessionMap; + + mutable mutex mu_; + SessionMap sessions_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(OpSegment); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ diff --git a/tensorflow/core/framework/op_segment_test.cc b/tensorflow/core/framework/op_segment_test.cc new file mode 100644 index 0000000000..6297718df8 --- /dev/null +++ b/tensorflow/core/framework/op_segment_test.cc @@ -0,0 +1,142 @@ +#include "tensorflow/core/framework/op_segment.h" + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { + +class OpSegmentTest : public ::testing::Test { + protected: + DeviceBase device_; + std::vector int32_nodedefs_; + std::vector float_nodedefs_; + + OpSegmentTest() : device_(Env::Default()) { + RequireDefaultOps(); + for (int i = 0; i < 10; ++i) { + NodeDef def; + TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") + .Input("x", 0, DT_INT32) + .Input("y", 0, DT_INT32) + .Finalize(&def)); + int32_nodedefs_.push_back(def); + TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") + .Input("x", 0, DT_FLOAT) + .Input("y", 0, DT_FLOAT) + .Finalize(&def)); + float_nodedefs_.push_back(def); + } + } + + void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) { + ASSERT_NE(op, nullptr); + EXPECT_EQ(expected.DebugString(), op->def().DebugString()); + EXPECT_EQ(2, op->num_inputs()); + EXPECT_EQ(dt, op->input_type(0)); + EXPECT_EQ(dt, op->input_type(1)); + EXPECT_EQ(1, op->num_outputs()); + EXPECT_EQ(dt, op->output_type(0)); + } + + OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) { + return [this, ndef](OpKernel** kernel) { + Status s; + auto created = + CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), *ndef, &s); + if (s.ok()) { + *kernel = created.release(); + } + return s; + }; + } +}; + +TEST_F(OpSegmentTest, Basic) { + OpSegment opseg; + OpKernel* op; + + opseg.AddHold("A"); + opseg.AddHold("B"); + for (int i = 0; i < 10; ++i) { + // Register in session A. + auto* ndef = &float_nodedefs_[i]; + EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef))); + ValidateOpAndTypes(op, *ndef, DT_FLOAT); + + // Register in session B. + ndef = &int32_nodedefs_[i]; + EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef))); + ValidateOpAndTypes(op, *ndef, DT_INT32); + } + + auto reterr = [](OpKernel** kernel) { + return errors::Internal("Should not be called"); + }; + for (int i = 0; i < 10; ++i) { + // Lookup op in session A. + EXPECT_OK(opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr)); + ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT); + + // Lookup op in session B. 
+ EXPECT_OK(opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr)); + ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32); + } + + opseg.RemoveHold("A"); + opseg.RemoveHold("B"); +} + +TEST_F(OpSegmentTest, SessionNotFound) { + OpSegment opseg; + OpKernel* op; + NodeDef def = float_nodedefs_[0]; + Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + EXPECT_TRUE(errors::IsNotFound(s)) << s; +} + +TEST_F(OpSegmentTest, CreateFailure) { + OpSegment opseg; + OpKernel* op; + NodeDef def = float_nodedefs_[0]; + def.set_op("nonexistop"); + opseg.AddHold("A"); + Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + EXPECT_TRUE(errors::IsNotFound(s)) << s; + opseg.RemoveHold("A"); +} + +TEST_F(OpSegmentTest, AddRemoveHolds) { + OpSegment opseg; + OpKernel* op; + const auto& ndef = int32_nodedefs_[0]; + + // No op. + opseg.RemoveHold("null"); + + // Thread1 register the op and wants to ensure it alive. + opseg.AddHold("foo"); + EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef))); + + // Thread2 starts some execution needs "op" to be alive. + opseg.AddHold("foo"); + + // Thread1 clears session "foo". E.g., a master sends CleanupGraph + // before an execution finishes. + opseg.RemoveHold("foo"); + + // Thread2 should still be able to access "op". + ValidateOpAndTypes(op, ndef, DT_INT32); + + // Thread2 then remove its hold on "foo". + opseg.RemoveHold("foo"); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h new file mode 100644 index 0000000000..a765c211cb --- /dev/null +++ b/tensorflow/core/framework/queue_interface.h @@ -0,0 +1,77 @@ +#ifndef TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_ +#define TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +// All implementations must be thread-safe. +class QueueInterface : public ResourceBase { + public: + typedef std::vector Tuple; + typedef AsyncOpKernel::DoneCallback DoneCallback; + typedef std::function CallbackWithTuple; + + virtual Status ValidateTuple(const Tuple& tuple) = 0; + virtual Status ValidateManyTuple(const Tuple& tuple) = 0; + + // Stashes a function object for future execution, that will eventually + // enqueue the tuple of tensors into the queue, and returns immediately. The + // function object is guaranteed to call 'callback'. + virtual void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) = 0; + + // Same as above, but the component tensors are sliced along the 0th dimension + // to make multiple queue-element components. + virtual void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) = 0; + + // Stashes a function object for future execution, that will eventually + // dequeue an element from the queue and call 'callback' with that tuple + // element as argument. + virtual void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) = 0; + + // Same as above, but the stashed function object will attempt to dequeue + // num_elements items. + virtual void TryDequeueMany(int num_elements, OpKernelContext* ctx, + CallbackWithTuple callback) = 0; + + // Signals that no more elements will be enqueued, and optionally + // cancels pending Enqueue(Many) operations. 
+ // + // After calling this function, subsequent calls to Enqueue(Many) + // will fail. If `cancel_pending_enqueues` is true, all pending + // calls to Enqueue(Many) will fail as well. + // + // After calling this function, all current and subsequent calls to + // Dequeue(Many) will fail instead of blocking (though they may + // succeed if they can be satisfied by the elements in the queue at + // the time it was closed). + virtual void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, + DoneCallback callback) = 0; + + // Assuming *this represents a shared queue, verify that it matches + // another instantiation indicated by node_def. + virtual Status MatchesNodeDef(const NodeDef& node_def) = 0; + + // Returns the number of elements in the queue. + virtual int32 size() = 0; + + virtual const DataTypeVector& component_dtypes() const = 0; + + string DebugString() override { return "A queue"; } + + protected: + virtual ~QueueInterface() {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_ diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h new file mode 100644 index 0000000000..b307c37f01 --- /dev/null +++ b/tensorflow/core/framework/reader_interface.h @@ -0,0 +1,66 @@ +#ifndef TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_ +#define TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_ + +#include +#include +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +class QueueInterface; +class ReaderInterface; + +// Readers are the mechanism for reading records from files in +// TensorFlow graphs. Each supported file format has a corresponding +// ReaderInterface descendant and a corresponding Op & OpKernel +// (implemented using ReaderOpKernel from reader_op_kernel.h). +// +// To use a Reader, you first encode "work" (some string, typically a +// filename) in the Reader's "work queue". It then processes the +// "work" (reading records from the file), to produce key/value +// strings. The methods of this class are called by ReaderFoo ops, +// so see ../ops/io_ops.cc for detailed descriptions. +// +// All descendants of this class must be thread-safe. +// +// See the design document here: +// https://docs.google.com/document/d/1UAgZOoeehYr20TdzW2CoZ30V-aqQphU4SwKXsW7eJv4/edit# + +// TODO(josh11b): Switch this to Async. +class ReaderInterface : public ResourceBase { + public: + // Read a single record into *key / *value. May get more work from + // *queue if the current work is complete. Sets the status on + // *context with an OutOfRange Status if the the current work is + // complete and the queue is done (closed and empty). + // This method may block. + virtual void Read(QueueInterface* queue, string* key, string* value, + OpKernelContext* context) = 0; + + // Restore this reader to its newly-constructed state. + virtual Status Reset() = 0; + + // Accessors + virtual int64 NumRecordsProduced() = 0; + virtual int64 NumWorkUnitsCompleted() = 0; + + // -- Serialization/Restoration support -- + // Not all readers will support saving and restoring state. + virtual Status SerializeState(string* state) = 0; + // Note: Must Reset on error. 
+ virtual Status RestoreState(const string& state) = 0; + + string DebugString() override { return "a reader"; } + + protected: + virtual ~ReaderInterface() {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_ diff --git a/tensorflow/core/framework/reader_op_kernel.cc b/tensorflow/core/framework/reader_op_kernel.cc new file mode 100644 index 0000000000..719f27d94b --- /dev/null +++ b/tensorflow/core/framework/reader_op_kernel.cc @@ -0,0 +1,39 @@ +#include "tensorflow/core/framework/reader_op_kernel.h" + +namespace tensorflow { + +ReaderOpKernel::ReaderOpKernel(OpKernelConstruction* context) + : OpKernel(context), have_handle_(false) { + OP_REQUIRES_OK(context, context->allocate_persistent( + tensorflow::DT_STRING, + tensorflow::TensorShape({2}), &handle_, nullptr)); +} + +ReaderOpKernel::~ReaderOpKernel() { + if (have_handle_ && cinfo_.resource_is_private_to_kernel()) { + TF_CHECK_OK(cinfo_.resource_manager()->Delete( + cinfo_.container(), cinfo_.name())); + } +} + +void ReaderOpKernel::Compute(OpKernelContext* ctx) { + mutex_lock l(mu_); + if (!have_handle_) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), false)); + ReaderInterface* reader; + OP_REQUIRES_OK(ctx, + cinfo_.resource_manager()->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &reader, + [this](ReaderInterface** ret) { + *ret = factory_(); + return Status::OK(); + })); + auto h = handle_.AccessTensor(ctx)->flat(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + have_handle_ = true; + } + ctx->set_output_ref(0, &mu_, handle_.AccessTensor(ctx)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h new file mode 100644 index 0000000000..8e5cc50c9b --- /dev/null +++ b/tensorflow/core/framework/reader_op_kernel.h @@ -0,0 +1,42 @@ +#ifndef TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_ +#define TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/reader_interface.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Implementation for ops providing a Reader. +class ReaderOpKernel : public OpKernel { + public: + explicit ReaderOpKernel(OpKernelConstruction* context); + ~ReaderOpKernel() override; + + void Compute(OpKernelContext* context) override; + + // Must be called by descendants before the first call to Compute() + // (typically called during construction). factory must return a + // ReaderInterface descendant allocated with new that ReaderOpKernel + // will take ownership of. + void SetReaderFactory(std::function factory) { + mutex_lock l(mu_); + DCHECK(!have_handle_); + factory_ = factory; + } + + private: + mutex mu_; + bool have_handle_ GUARDED_BY(mu_); + PersistentTensor handle_ GUARDED_BY(mu_); + ContainerInfo cinfo_; + std::function factory_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_ diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h new file mode 100644 index 0000000000..18473aea2e --- /dev/null +++ b/tensorflow/core/framework/register_types.h @@ -0,0 +1,90 @@ +#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ +// This file is used by cuda code and must remain compilable by nvcc. 
+ +#include "tensorflow/core/platform/port.h" + +// Macros to apply another macro to lists of supported types. If you change +// the lists of types, please also update the list in types.cc. +// +// See example uses of these macros in core/ops. +// +// +// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple +// times by passing each invocation a data type supported by TensorFlow. +// +// The different variations pass different subsets of the types. +// TF_CALL_ALL_TYPES(m) applied "m" to all types supported by TensorFlow. +// The set of types depends on the compilation platform. +//. +// This can be used to register a different template instantiation of +// an OpKernel for different signatures, e.g.: +/* + #define REGISTER_PARTITION(type) \ + REGISTER_TF_OP_KERNEL("partition", DEVICE_CPU, #type ", int32", \ + PartitionOp); + TF_CALL_ALL_TYPES(REGISTER_PARTITION) + #undef REGISTER_PARTITION +*/ + +#ifndef __ANDROID__ + +// Call "m" for all number types that support the comparison operations "<" and +// ">". +#define TF_CALL_REAL_NUMBER_TYPES(m) \ + m(float); \ + m(double); \ + m(int64); \ + m(int32); \ + m(uint8); \ + m(int16); \ + m(int8) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + m(float); \ + m(double); \ + m(int64); \ + m(uint8); \ + m(int16); \ + m(int8) + +// Call "m" for all number types, including complex64. +#define TF_CALL_NUMBER_TYPES(m) \ + TF_CALL_REAL_NUMBER_TYPES(m); \ + m(complex64) + +#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m); \ + m(complex64) + +// Call "m" on all types. +#define TF_CALL_ALL_TYPES(m) \ + TF_CALL_NUMBER_TYPES(m); \ + m(bool); \ + m(string) + +// Call "m" on all types supported on GPU. +#define TF_CALL_GPU_NUMBER_TYPES(m) \ + m(float); \ + m(double) + +#else // __ANDROID__ + +#define TF_CALL_REAL_NUMBER_TYPES(m) \ + m(float); \ + m(int32) + +#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) m(float) + +#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) + +#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) + +// Maybe we could put an empty macro here for Android? +#define TF_CALL_GPU_NUMBER_TYPES(m) m(float) + +#endif // __ANDROID__ + +#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc new file mode 100644 index 0000000000..7f551ea65f --- /dev/null +++ b/tensorflow/core/framework/rendezvous.cc @@ -0,0 +1,263 @@ +#include "tensorflow/core/framework/rendezvous.h" + +#include +#include + +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +/* static */ +string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation, + const string& dst_device, const string& name, + const FrameAndIter& frame_iter) { + // NOTE: ';' is not used in the device name's job name. + // + // We include both sender and receiver in the key to facilitate + // debugging. For correctness, we only need to encode the receiver. + // + // "src_incarnation" is used to distinguish a worker when it + // restarts. 
+ return strings::StrCat(src_device, ";", strings::FpToString(src_incarnation), + ";", dst_device, ";", name, ";", frame_iter.frame_id, + ":", frame_iter.iter_id); +} + +/* static */ +Status Rendezvous::ParseKey(const string& key, ParsedKey* out) { + // TODO(zhifengc): This code is not fast enough. + std::vector parts = str_util::Split(key, ';'); + if (parts.size() == 5 && + DeviceNameUtils::ParseFullName(parts[0], &out->src) && + strings::StringToFp(parts[1], &out->src_incarnation) && + DeviceNameUtils::ParseFullName(parts[2], &out->dst) && + !parts[3].empty()) { + out->src_device = parts[0]; + out->dst_device = parts[2]; + out->edge_name = parts[3]; + return Status::OK(); + } + return errors::InvalidArgument("Invalid rendezvous key: ", key); +} + +Rendezvous::~Rendezvous() {} + +Status Rendezvous::Recv(const string& key, const Args& recv_args, Tensor* val, + bool* is_dead) { + Status ret; + Notification n; + RecvAsync(key, recv_args, + [&ret, &n, val, is_dead](const Status& s, const Args& send_args, + const Args& recv_args, const Tensor& v, + const bool dead) { + ret = s; + *val = v; + *is_dead = dead; + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + +class LocalRendezvousImpl : public Rendezvous { + public: + explicit LocalRendezvousImpl(bool tolerate_dup_recv) + : tolerate_dup_recv_(tolerate_dup_recv) {} + + Status Send(const string& key, const Args& send_args, const Tensor& val, + const bool is_dead) override { + VLOG(2) << "Send " << this << " " << key; + DoneCallback waiter = nullptr; + Args recv_args; + { + mutex_lock l(mu_); + if (!status_.ok()) { + return status_; + } + Item* item = nullptr; + Table::iterator iter = table_.find(key); + if (iter == table_.end()) { + // There is no waiter for this message. Insert the message + // into the waiters table. The waiter will pick it up when + // arrives. + item = new Item; + item->waiter = nullptr; + item->value = val; + item->is_dead = is_dead; + if (send_args.device_context) { + send_args.device_context->Ref(); + item->send_dev_context = send_args.device_context; + } + item->recv_dev_context = nullptr; + + // The allocator attributes of item->value. + item->send_alloc_attrs = send_args.alloc_attrs; + + CHECK(table_.insert({key, item}).second); + return Status::OK(); + } else { + item = iter->second; + if (item->waiter == nullptr) { + // There is already a message in the table under the key. + // Should not happen unless it has a waiter. + return errors::Aborted("Duplicated send: ", key); + } + // Mark item as complete. + item->has_been_recvd = true; + waiter = item->waiter; + item->waiter = nullptr; + // The ref on recv_dev_context transfers below. + recv_args.device_context = item->recv_dev_context; + recv_args.alloc_attrs = item->recv_alloc_attrs; + item->recv_dev_context = nullptr; + if (tolerate_dup_recv_) { + item->value = val; + item->is_dead = is_dead; + if (send_args.device_context) { + send_args.device_context->Ref(); + item->send_dev_context = send_args.device_context; + } + item->send_alloc_attrs = send_args.alloc_attrs; + } + } + } // mutex + // Notify the waiter by invoking its done closure, outside scope + // of the table lock. + waiter(Status::OK(), send_args, recv_args, val, is_dead); + if (recv_args.device_context) recv_args.device_context->Unref(); + return Status::OK(); + } + + void RecvAsync(const string& key, const Args& recv_args, + DoneCallback done) override { + VLOG(2) << "Recv " << this << " " << key; + mu_.lock(); + if (!status_.ok()) { + // Rendezvous has been aborted. 
+ Status s = status_; + mu_.unlock(); + done(s, Args(), recv_args, Tensor(), false); + return; + } + Table::iterator iter = table_.find(key); + if (iter != table_.end()) { + Item* item = iter->second; + if (item->has_been_recvd && !tolerate_dup_recv_) { + mu_.unlock(); + done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args, + Tensor(), false); + } else if (item->waiter == nullptr || tolerate_dup_recv_) { + // A message has already arrived and is stored in the table + // under this key. Consumes the message and invokes the done + // closure. + Tensor v = item->value; + if (!tolerate_dup_recv_) { + item->value = Tensor(); + } + item->has_been_recvd = true; + // Before dropping the table lock, capture the item values. + // DeviceContext is only non-null for non-CPU devices. + // If we capture the send_dev_context, we need to hold a ref on + // it. Our caller will have a ref on the recv_dev_context, + // which is not in our table. + DeviceContext* send_dev_context = item->send_dev_context; + if (send_dev_context) send_dev_context->Ref(); + bool is_dead = item->is_dead; + mu_.unlock(); + Args send_args; + send_args.device_context = item->send_dev_context; + send_args.alloc_attrs = item->send_alloc_attrs; + done(Status::OK(), send_args, recv_args, v, is_dead); + if (send_dev_context) send_dev_context->Unref(); + } else { + // Already have a waiter in the waiters table under this key, + // which should not happen. + mu_.unlock(); + done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args, + Tensor(), false); + } + return; + } + // Waiting for a message that has not arrived yet. Insert into the + // waiting table. The done closure will be invoked when the + // message arrives. + Item* item = new Item; + item->waiter = done; + if (recv_args.device_context) { + item->recv_dev_context = recv_args.device_context; + item->recv_alloc_attrs = recv_args.alloc_attrs; + item->recv_dev_context->Ref(); + } + CHECK(table_.insert({key, item}).second); + mu_.unlock(); + return; + } + + void StartAbort(const Status& status) override { + CHECK(!status.ok()); + std::vector items; + { + mutex_lock l(mu_); + if (!status_.ok()) return; + status_ = status; + items.reserve(table_.size()); + for (const auto& p : table_) items.push_back(p.second); + table_.clear(); + } + for (Item* item : items) { + if (item->waiter != nullptr) { + item->waiter(status, Args(), Args(), Tensor(), false); + } + delete item; + } + } + + private: + typedef LocalRendezvousImpl ME; + const bool tolerate_dup_recv_; + + struct Item { + DoneCallback waiter = nullptr; + Tensor value; + bool is_dead = false; + bool has_been_recvd = false; + DeviceContext* send_dev_context = nullptr; + DeviceContext* recv_dev_context = nullptr; + AllocatorAttributes send_alloc_attrs; + AllocatorAttributes recv_alloc_attrs; + + ~Item() { + if (send_dev_context) { + send_dev_context->Unref(); + } + if (recv_dev_context) { + recv_dev_context->Unref(); + } + } + }; + typedef std::unordered_map Table; + + // TODO(zhifengc): shard table_. 
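+  // (Sketch of intent, not in the original sources: table_ maps the string
+  //  key produced by Rendezvous::CreateKey to an Item that parks either the
+  //  sent value or the waiting DoneCallback; when tolerate_dup_recv_ is true
+  //  the value is retained after the first Recv so duplicate Recv calls can
+  //  still be answered.)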
+ mutex mu_; + Table table_ GUARDED_BY(mu_); + Status status_; + + ~LocalRendezvousImpl() override { + for (auto i : table_) { + delete i.second; + } + } + + TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl); +}; + +Rendezvous* NewLocalRendezvous(bool tolerate_dup_recv) { + return new LocalRendezvousImpl(tolerate_dup_recv); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h new file mode 100644 index 0000000000..94fbfb2523 --- /dev/null +++ b/tensorflow/core/framework/rendezvous.h @@ -0,0 +1,102 @@ +#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ +#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ + +#include + +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// A Rendezvous is an abstraction for passing a Tensor +// from a producer to a consumer, where the consumer may safely +// request the Tensor before or after it has been produced. A +// producer never blocks when using a Rendezvous. A consumer has the +// choice of making a blocking call or providing a callback: in either +// case, the consumer receives the Tensor as soon as it is available. +// +// A Rendezvous key encodes a single pair. It is +// an error to call Send() or Recv*() more than once with the same +// key. +class Rendezvous : public core::RefCounted { + public: + struct Args { + DeviceContext* device_context = nullptr; + AllocatorAttributes alloc_attrs; + }; + + // Constructs a rendezvouz key for the tensor of "name" sent from + // "src_device" to "dst_device". The tensor is generated in the frame + // and iteration specified by "frame_iter". + static string CreateKey(const string& src_device, uint64 src_incarnation, + const string& dst_device, const string& name, + const FrameAndIter& frame_iter); + + // Parses the key constructed by CreateKey and parse src/dst device + // names into structures respectively. + struct ParsedKey { + string src_device; + DeviceNameUtils::ParsedName src; + uint64 src_incarnation = 0; + string dst_device; + DeviceNameUtils::ParsedName dst; + string edge_name; + }; + static Status ParseKey(const string& key, ParsedKey* out); + + // The caller is a tensor producer and it sends a message (a tensor + // "val" and a bool "is_dead") under the given "key". + // + // {val, is_dead} is bundled as a message sent and received. + // Typically, is_dead is set by some control flow nodes + // (e.g., a not-take branch). args is passed by Send to the + // Recv function to communicate any information that the Recv + // function might need. This is typically only necessary for + // Send/Recv on the same worker. + // + // Send() never blocks. + virtual Status Send(const string& key, const Args& args, const Tensor& val, + const bool is_dead) = 0; + + // Callback provided by a tensor consumer waiting on the rendezvous. + // It will be invoked when the tensor is available, or when a non-OK + // status arises in the production of that tensor. It also gets + // two Rendezvous::Args, one provided by the sender, the other by the + // receiver, which may be needed when a non-CPU device is in use + // by either side. + typedef std::function DoneCallback; + + virtual void RecvAsync(const string& key, const Args& args, + DoneCallback done) = 0; + + // Synchronous wrapper for RecvAsync. 
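+  //
+  // Example (illustrative sketch, mirroring the unit tests; "rendez" and
+  // "key" are assumed to be a Rendezvous* and a key built by CreateKey):
+  //   Tensor val(DT_STRING);
+  //   bool is_dead = false;
+  //   Rendezvous::Args args;
+  //   Status s = rendez->Recv(key, args, &val, &is_dead);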
+ Status Recv(const string& key, const Args& args, Tensor* val, bool* is_dead); + + // Aborts all pending and future Send/Recv with the given "status". + // + // StartAbort() does not wait for ongoing calls to finish. + // REQUIRES: !status.ok() + virtual void StartAbort(const Status& status) = 0; + + protected: + ~Rendezvous() override; +}; + +// Returns a Rendezvous instance that is limited to use only by +// producers and consumers in the local process. The caller assumes +// ownership of one Ref() on the returned object. +// +// If "tolerate_dup_recv" is true, then the Rendezvous will retain +// already Recv'd values and make them available to duplicate Recv +// calls. This may be useful if the RPC layer is not reliable, but +// comes at the cost of higher memory consumption. +Rendezvous* NewLocalRendezvous(bool tolerate_dup_recv = false); + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc new file mode 100644 index 0000000000..32011a468f --- /dev/null +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -0,0 +1,314 @@ +#include "tensorflow/core/framework/rendezvous.h" + +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +TEST(RendezvousTest, Key) { + const string key = Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/CPU:0", 7890, + "/job:mnist/replica:1/task:2/GPU:0", "var0", FrameAndIter(0, 0)); + EXPECT_EQ(key, + "/job:mnist/replica:1/task:2/CPU:0;" + "0000000000001ed2;" // 7890 = 0x1ed2 + "/job:mnist/replica:1/task:2/GPU:0;" + "var0;" + "0:0"); + Rendezvous::ParsedKey parsed; + EXPECT_OK(Rendezvous::ParseKey(key, &parsed)); + EXPECT_EQ(parsed.src_device, "/job:mnist/replica:1/task:2/CPU:0"); + EXPECT_EQ(parsed.src_incarnation, 7890); + EXPECT_EQ(parsed.src.type, "CPU"); + EXPECT_EQ(parsed.dst_device, "/job:mnist/replica:1/task:2/GPU:0"); + EXPECT_EQ(parsed.dst.type, "GPU"); + + EXPECT_FALSE(Rendezvous::ParseKey("foo;bar;baz", &parsed).ok()); + EXPECT_FALSE(Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;" + "/job:mnist/replica:1/task:2/GPU:0;", + &parsed) + .ok()); + EXPECT_FALSE( + Rendezvous::ParseKey(strings::StrCat(key, ";", key), &parsed).ok()); +} + +class LocalRendezvousTest : public ::testing::Test { + public: + LocalRendezvousTest() + : threads_(new thread::ThreadPool(Env::Default(), "test", 16)) { + rendez_ = NewLocalRendezvous(); + } + + ~LocalRendezvousTest() override { + rendez_->Unref(); + delete threads_; + } + + void SchedClosure(std::function fn) { threads_->Schedule(fn); } + + Rendezvous* rendez_; + + private: + thread::ThreadPool* threads_; +}; + +// string -> Tensor +Tensor V(const string& content) { + Tensor tensor(DT_STRING, TensorShape({})); + 
tensor.scalar()() = content; + return tensor; +} + +// Tensor -> string +string V(const Tensor& tensor) { + CHECK_EQ(tensor.dtype(), DT_STRING); + CHECK(TensorShapeUtils::IsScalar(tensor.shape())); + return tensor.scalar()(); +} + +TEST_F(LocalRendezvousTest, SendRecv) { + Rendezvous::Args args; + ASSERT_OK(rendez_->Send("foo", args, V("hello"), false)); + EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, V("hello"), false))); + Tensor val(DT_STRING); + bool is_dead = false; + ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead)); + EXPECT_EQ("hello", V(val)); +} + +TEST_F(LocalRendezvousTest, RecvSend) { + SchedClosure([this]() { + Env::Default()->SleepForMicroseconds(10000); + Rendezvous::Args args; + ASSERT_OK(rendez_->Send("foo", args, V("hello"), false)); + }); + Tensor val(DT_STRING); + bool is_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead)); + EXPECT_EQ("hello", V(val)); +} + +TEST_F(LocalRendezvousTest, DuplicateWaiterRecv) { + SchedClosure([this]() { + Tensor t(DT_STRING); + bool is_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead)); + ASSERT_OK(rendez_->Send("bar", args, t, is_dead)); + }); + Env::Default()->SleepForMicroseconds(1000000); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead))); + ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead)); + ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead)); + EXPECT_EQ("secret msg", V(val)); +} + +TEST_F(LocalRendezvousTest, DuplicateSerialRecv) { + SchedClosure([this]() { + Tensor t(DT_STRING); + bool is_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead)); + ASSERT_OK(rendez_->Send("bar", args, t, is_dead)); + }); + Env::Default()->SleepForMicroseconds(1000000); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead)); + ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead)); + EXPECT_EQ("secret msg", V(val)); + EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead))); +} + +// A simple structure that behaves a bit like a blocking counter. The +// user that decrements counter to 0 does done.Notify(), and the main +// thread waits for done to be notified. 
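+// Illustrative usage sketch (mirrors the RandomSendRecv test below):
+//   BlockingState state;
+//   state.counter = N;                  // N worker closures outstanding
+//   // ... each worker, when it finishes:
+//   bool done = false;
+//   {
+//     mutex_lock l(state.lock);
+//     if (--state.counter == 0) done = true;
+//   }
+//   if (done) state.done.Notify();
+//   // ... main thread:
+//   state.done.WaitForNotification();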
+struct BlockingState { + mutex lock; + int counter; + Notification done; +}; + +TEST_F(LocalRendezvousTest, RandomSendRecv) { + static const int N = 1000; + BlockingState state; + state.counter = N; + for (int i = 0; i < N; ++i) { + SchedClosure([this, i]() { + random::PhiloxRandom philox(testing::RandomSeed() + i, 17); + random::SimplePhilox rnd(&philox); + Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000)); + Rendezvous::Args args; + ASSERT_OK(rendez_->Send(strings::StrCat(i), args, V(strings::StrCat(i)), + false)); + }); + SchedClosure([this, &state, i]() { + random::PhiloxRandom philox(testing::RandomSeed() + N + i, 17); + random::SimplePhilox rnd(&philox); + Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000)); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Recv(strings::StrCat(i), args, &val, &val_dead)); + EXPECT_EQ(strings::StrCat(i), V(val)); + bool done = false; + { + mutex_lock l(state.lock); + state.counter--; + if (state.counter == 0) { + done = true; + } + } + if (done) { + state.done.Notify(); + } + }); + } + + state.done.WaitForNotification(); +} + +TEST_F(LocalRendezvousTest, RecvAbort) { + rendez_->Ref(); + SchedClosure([this]() { + rendez_->StartAbort(errors::Aborted("")); // abort + rendez_->Unref(); + }); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + Status status = rendez_->Recv("foo", args, &val, &val_dead); + EXPECT_TRUE(errors::IsAborted(status)); +} + +// Similar to RecvAbort. But this test case ensures the main thread +// Recv() call happens after StartAbort(). +TEST_F(LocalRendezvousTest, RecvSleepAbort) { + rendez_->Ref(); + SchedClosure([this]() { + Env::Default()->SleepForMicroseconds(1000000); + rendez_->StartAbort(errors::Aborted("")); // abort + rendez_->Unref(); + }); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + Status status = rendez_->Recv("foo", args, &val, &val_dead); + EXPECT_TRUE(errors::IsAborted(status)); +} + +TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) { + rendez_->StartAbort(errors::Aborted("")); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, val, val_dead))); + EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead))); +} + +class DummyDeviceContext : public DeviceContext { + public: + explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {} + ~DummyDeviceContext() override {} + int stream_id() const { return stream_id_; } + + private: + const int stream_id_; +}; + +TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) { + Rendezvous::Args args; + args.device_context = new DummyDeviceContext(123); + + ASSERT_OK(rendez_->Send("foo", args, V("hello"), false)); + + Notification n; + Rendezvous::Args args1; + args1.device_context = new DummyDeviceContext(1); + rendez_->RecvAsync("foo", args1, [&n](const Status& s, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& val, bool is_dead) { + CHECK_EQ(123, + dynamic_cast(send_args.device_context) + ->stream_id()); + n.Notify(); + }); + + n.WaitForNotification(); + args.device_context->Unref(); + args1.device_context->Unref(); +} + +static void BM_SendRecv(int iters) { + Rendezvous* rendez = NewLocalRendezvous(); + Tensor orig = V("val"); + Tensor val(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + if (iters > 0) { + while (iters--) { + s = rendez->Send("foo", args, 
orig, is_dead); + s = rendez->Recv("foo", args, &val, &is_dead); + } + CHECK_EQ(V(val), V(orig)); + } + rendez->Unref(); +} +BENCHMARK(BM_SendRecv); + +static void BM_RecvSend(int iters) { + thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1); + + // The main thread sends "foo" for iters/2 times and receives "bar" + // for iters/2 times. The other thread sends "bar" for iters/2 + // times and receives "foo" for iters/2 times. + Rendezvous* rendez = NewLocalRendezvous(); + pool->Schedule([rendez, iters]() { + Tensor bar = V("bar"); + Tensor foo(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + for (int i = 0; i < iters / 2; ++i) { + s = rendez->Recv("foo", args, &foo, &is_dead); + s = rendez->Send("bar", args, bar, is_dead); + } + CHECK_EQ("foo", V(foo)); + }); + Tensor foo = V("foo"); + Tensor bar(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + for (int i = 0; i < iters / 2; ++i) { + s = rendez->Send("foo", args, foo, is_dead); + s = rendez->Recv("bar", args, &bar, &is_dead); + } + CHECK_EQ("bar", V(bar)); + delete pool; +} +BENCHMARK(BM_RecvSend); + +} // namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc new file mode 100644 index 0000000000..42326f068e --- /dev/null +++ b/tensorflow/core/framework/resource_mgr.cc @@ -0,0 +1,146 @@ +#include "tensorflow/core/framework/resource_mgr.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +ResourceMgr::ResourceMgr() : default_container_("localhost") {} + +ResourceMgr::ResourceMgr(const string& default_container) + : default_container_(default_container) {} + +ResourceMgr::~ResourceMgr() { Clear(); } + +void ResourceMgr::Clear() { + mutex_lock l(mu_); + for (const auto& p : containers_) { + for (const auto& q : *p.second) { + q.second->Unref(); + } + delete p.second; + } + containers_.clear(); +} + +Status ResourceMgr::DoCreate(const string& container, std::type_index type, + const string& name, ResourceBase* resource) { + { + mutex_lock l(mu_); + Container** b = &containers_[container]; + if (*b == nullptr) { + *b = new Container; + } + if ((*b)->insert({{type, name}, resource}).second) { + return Status::OK(); + } + } + resource->Unref(); + return errors::AlreadyExists("Resource ", container, "/", name, "/", + type.name()); +} + +Status ResourceMgr::DoLookup(const string& container, std::type_index type, + const string& name, + ResourceBase** resource) const { + mutex_lock l(mu_); + const Container* b = gtl::FindPtrOrNull(containers_, container); + if (b == nullptr) { + return errors::NotFound("Container ", container, " does not exist."); + } + auto r = gtl::FindPtrOrNull(*b, {type, name}); + if (r == nullptr) { + return errors::NotFound("Resource ", container, "/", name, "/", type.name(), + " does not exist."); + } + *resource = const_cast(r); + (*resource)->Ref(); + return Status::OK(); +} + +Status ResourceMgr::DoDelete(const string& container, std::type_index type, + const string& name) { + ResourceBase* base = nullptr; + { + mutex_lock l(mu_); + Container* b = gtl::FindPtrOrNull(containers_, container); + if (b == nullptr) { + return errors::NotFound("Container ", container, " does not exist."); + } + auto iter = b->find({type, name}); + if 
(iter == b->end()) { + return errors::NotFound("Resource ", container, "/", name, "/", + type.name(), " does not exist."); + } + base = iter->second; + b->erase(iter); + } + CHECK(base != nullptr); + base->Unref(); + return Status::OK(); +} + +Status ResourceMgr::Cleanup(const string& container) { + Container* b = nullptr; + { + mutex_lock l(mu_); + auto iter = containers_.find(container); + if (iter == containers_.end()) { + return errors::NotFound("Container ", container, " does not exist."); + } + b = iter->second; + containers_.erase(iter); + } + CHECK(b != nullptr); + for (const auto& p : *b) { + p.second->Unref(); + } + delete b; + return Status::OK(); +} + +Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default) { + CHECK(rmgr); + rmgr_ = rmgr; + string attr_container; + TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container)); + static RE2 container_re("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); + if (!attr_container.empty() && + !RE2::FullMatch(attr_container, container_re)) { + return errors::InvalidArgument("container contains invalid characters: ", + attr_container); + } + string attr_shared_name; + TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &attr_shared_name)); + if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) { + return errors::InvalidArgument("shared_name cannot start with '_':", + attr_shared_name); + } + if (!attr_container.empty()) { + container_ = attr_container; + } else { + container_ = rmgr_->default_container(); + } + if (!attr_shared_name.empty()) { + name_ = attr_shared_name; + } else if (use_node_name_as_default) { + name_ = ndef.name(); + } else { + resource_is_private_to_kernel_ = true; + static std::atomic counter(0); + name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name()); + } + return Status::OK(); +} + +string ContainerInfo::DebugString() const { + return strings::StrCat("[", container(), ",", name(), ",", + resource_is_private_to_kernel() ? "private" : "public", + "]"); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h new file mode 100644 index 0000000000..65e859caf1 --- /dev/null +++ b/tensorflow/core/framework/resource_mgr.h @@ -0,0 +1,280 @@ +#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ +#define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// A ResourceMgr instance keeps track of named and typed resources +// grouped into containers. +// +// Each resource must be represented as a sub-class of ResourceBase, +// which is reference counted explicitly. Each named resource is +// registered with ResourceMgr under a named "container" name. At any +// time, there is at most one instance of a resource given the container +// name, the resource type and the resource name. +// +// All resources for a given container can be dropped by one call of +// Cleanup(). +// +// E.g., +// struct MyVar : public ResourceBase { +// mutex mu; +// Tensor val; +// } +// +// ResourceMgr rm; +// +// // Create a var. 
+// MyVar* my_var = new MyVar; +// my_var.val = Tensor(DT_FLOAT, my_shape); +// my_val.val.flat().setZeros(); // 0 initialized. +// ctx->SetStatus(rm.Create("my_container", "my_name", my_val)); +// +// // += a variable. +// MyVar* my_var = nullptr; +// Status s = rm.Lookup("my_container", "my_name", &my_var); +// if (s.ok()) { +// my_var->val.flat() += grad; +// } +// my_var->Unref(); // Or use ScopedUnref(). +// ctx->SetStatus(s); +class ResourceBase : public core::RefCounted { + public: + // Returns a debug string for *this. + virtual string DebugString() = 0; +}; + +class ResourceMgr { + public: + ResourceMgr(); + explicit ResourceMgr(const string& default_container); + ~ResourceMgr(); + + // Returns the default container name for *this. + const string& default_container() const { return default_container_; } + + // Creates a resource "name" in the "container". The caller transfers + // the ownership of one ref on "resource" to *this + // + // REQUIRES: std::is_base_of + // REQUIRES: resource != nullptr. + template + Status Create(const string& container, const string& name, + T* resource) TF_MUST_USE_RESULT; + + // If "container" has a resource "name", returns it in "*resource" and + // the caller takes the ownership of one ref on "*resource". + // + // REQUIRES: std::is_base_of + // REQUIRES: resource != nullptr + template + Status Lookup(const string& container, const string& name, + T** resource) const TF_MUST_USE_RESULT; + + // If "container" has a resource "name", returns it in + // "*resource". Otherwise, invokes creator() to create the resource. + // The caller takes the ownership of one ref on "*resource". + // + // REQUIRES: std::is_base_of + // REQUIRES: resource != nullptr + template + Status LookupOrCreate(const string& container, const string& name, + T** resource, + std::function creator) TF_MUST_USE_RESULT; + + // Deletes the resource "name" from the "container". + // + // REQUIRES: std::is_base_of + template + Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT; + + // Deletes all resources from the "container" and removes the container. + Status Cleanup(const string& container) TF_MUST_USE_RESULT; + + // Deletes all resources in all containers. + void Clear(); + + private: + typedef std::pair Key; + struct KeyHash { + std::size_t operator()(const Key& k) const { + return Hash64(k.second.data(), k.second.size(), k.first.hash_code()); + } + }; + struct KeyEqual { + bool operator()(const Key& x, const Key& y) const { + return (x.second == y.second) && (x.first == y.first); + } + }; + typedef std::unordered_map Container; + + const string default_container_; + mutable mutex mu_; + std::unordered_map containers_ GUARDED_BY(mu_); + + Status DoCreate(const string& container, std::type_index type, + const string& name, + ResourceBase* resource) TF_MUST_USE_RESULT; + Status DoLookup(const string& container, std::type_index type, + const string& name, + ResourceBase** resource) const TF_MUST_USE_RESULT; + Status DoDelete(const string& container, std::type_index type, + const string& name) TF_MUST_USE_RESULT; + + TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr); +}; + +// Policy helper to decide which container/shared_name to use for a +// stateful kernel that accesses shared resource. +class ContainerInfo { + public: + // Analyze the node attribute of 'ndef' and decides the container and + // resource name the kernel should use for accessing the shared + // resource. + // + // 'ndef' is expected to have node attribute "container" and + // "shared_name". 
Returns non-OK if they are not provided or they are + // invalid. + // + // The policy is as following: + // * If the attribute "container" is non-empty, it is used as is. + // Otherwise, uses the resource manager's default container. + // * If the attribute "shared_name" is non-empty, it is used as is. + // Otherwise, if "use_node_name_as_default" is true, the kernel's + // node name is used as the resource name. Otherwise, a string + // unique to this process is used. + Status Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default); + Status Init(ResourceMgr* rmgr, const NodeDef& ndef) { + return Init(rmgr, ndef, false); + } + + // The policy decides that the kernel should access the resource in + // resource_manager(), the resource is in the container() and its + // name is name(). If resource_is_private_to_kernel() is true, the + // kernel should delete the resource when the kernel is deleted. + ResourceMgr* resource_manager() const { return rmgr_; } + const string& container() const { return container_; } + const string& name() const { return name_; } + bool resource_is_private_to_kernel() const { + return resource_is_private_to_kernel_; + } + + // Returns a readable string for *this. + string DebugString() const; + + private: + ResourceMgr* rmgr_ = nullptr; + string container_; + string name_; + bool resource_is_private_to_kernel_ = false; +}; + +// Helper for kernels to obtain 'resource' from the +// ctx->resource_manager(). +// +// "input_name" specifies the kernel's ref input which gives a string +// tensor with two elements, which specifies the container and +// resource name. +// +// Returns OK if the resource is found and transfers one ref of +// *resource to the caller. Otherwise, returns an error. +template +Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, + T** resource); + +// Implementation details below. + +template +void CheckDeriveFromResourceBase() { + static_assert(std::is_base_of::value, + "T must derive from ResourceBase"); +} + +template +Status ResourceMgr::Create(const string& container, const string& name, + T* resource) { + CheckDeriveFromResourceBase(); + CHECK(resource != nullptr); + return DoCreate(container, std::type_index(typeid(T)), name, resource); +} + +template +Status ResourceMgr::Lookup(const string& container, const string& name, + T** resource) const { + CheckDeriveFromResourceBase(); + ResourceBase* found = nullptr; + Status s = DoLookup(container, std::type_index(typeid(T)), name, &found); + if (s.ok()) { + // It's safe to down cast 'found' to T* since + // typeid(T).hash_code() is part of the map key. + *resource = static_cast(found); + } + return s; +} + +template +Status ResourceMgr::LookupOrCreate(const string& container, const string& name, + T** resource, + std::function creator) { + Status s; + *resource = nullptr; + while (*resource == nullptr) { + s = Lookup(container, name, resource); + if (s.ok()) break; + s = creator(resource); + if (!s.ok()) break; + s = Create(container, name, *resource); + if (s.ok()) { + (*resource)->Ref(); + break; + } + // Rare event. Concurrent racy creation. Redo the lookup. 
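+    // (Added note: if another thread won the race, the Create() above
+    //  returned AlreadyExists and DoCreate() already Unref'd our candidate
+    //  object, so it is safe to drop the pointer here and retry Lookup().)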
+ *resource = nullptr; + } + return s; +} + +template +Status ResourceMgr::Delete(const string& container, const string& name) { + CheckDeriveFromResourceBase(); + return DoDelete(container, std::type_index(typeid(T)), name); +} + +template +Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, + T** resource) { + string container; + string shared_name; + { + mutex* mu; + TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); + mutex_lock l(*mu); + Tensor tensor; + TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); + if (tensor.NumElements() != 2) { + return errors::InvalidArgument( + "Resource handle must have 2 elements, but had shape: ", + tensor.shape().DebugString()); + } + container = tensor.flat()(0); + shared_name = tensor.flat()(1); + } + return ctx->resource_manager()->Lookup(container, shared_name, resource); +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc new file mode 100644 index 0000000000..9f7ce3dde3 --- /dev/null +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -0,0 +1,173 @@ +#include "tensorflow/core/framework/resource_mgr.h" + +#include +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +class Resource : public ResourceBase { + public: + explicit Resource(const string& label) : label_(label) {} + ~Resource() override {} + + string DebugString() { return strings::StrCat("R/", label_); } + + private: + string label_; +}; + +class Other : public ResourceBase { + public: + explicit Other(const string& label) : label_(label) {} + ~Other() override {} + + string DebugString() { return strings::StrCat("O/", label_); } + + private: + string label_; +}; + +template +string Find(const ResourceMgr& rm, const string& container, + const string& name) { + T* r; + TF_CHECK_OK(rm.Lookup(container, name, &r)); + const string ret = r->DebugString(); + r->Unref(); + return ret; +} + +template +string LookupOrCreate(ResourceMgr* rm, const string& container, + const string& name, const string& label) { + T* r; + TF_CHECK_OK(rm->LookupOrCreate(container, name, &r, [&label](T** ret) { + *ret = new T(label); + return Status::OK(); + })); + const string ret = r->DebugString(); + r->Unref(); + return ret; +} + +static void HasError(const Status& s, const string& substr) { + EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + << s << ", expected substring " << substr; +} + +template +Status FindErr(const ResourceMgr& rm, const string& container, + const string& name) { + T* r; + Status s = rm.Lookup(container, name, &r); + CHECK(!s.ok()); + return s; +} + +TEST(ResourceMgrTest, Basic) { + ResourceMgr rm; + TF_CHECK_OK(rm.Create("foo", "bar", new Resource("cat"))); + TF_CHECK_OK(rm.Create("foo", "baz", new Resource("dog"))); + TF_CHECK_OK(rm.Create("foo", "bar", new Other("tiger"))); + + // Expected to fail. + HasError(rm.Create("foo", "bar", new Resource("kitty")), + "Already exists: Resource foo/bar"); + + // Expected to be found. + EXPECT_EQ("R/cat", Find(rm, "foo", "bar")); + EXPECT_EQ("R/dog", Find(rm, "foo", "baz")); + EXPECT_EQ("O/tiger", Find(rm, "foo", "bar")); + + // Expected to be not found. 
+ HasError(FindErr(rm, "bar", "foo"), "Not found: Container bar"); + HasError(FindErr(rm, "foo", "xxx"), "Not found: Resource foo/xxx"); + HasError(FindErr(rm, "foo", "baz"), "Not found: Resource foo/baz"); + + // Delete foo/bar/Resource. + TF_CHECK_OK(rm.Delete("foo", "bar")); + HasError(FindErr(rm, "foo", "bar"), "Not found: Resource foo/bar"); + + TF_CHECK_OK(rm.Create("foo", "bar", new Resource("kitty"))); + EXPECT_EQ("R/kitty", Find(rm, "foo", "bar")); + + // Drop the whole container foo. + TF_CHECK_OK(rm.Cleanup("foo")); + HasError(FindErr(rm, "foo", "bar"), "Not found: Container foo"); +} + +TEST(ResourceMgr, CreateOrLookup) { + ResourceMgr rm; + EXPECT_EQ("R/cat", LookupOrCreate(&rm, "foo", "bar", "cat")); + EXPECT_EQ("R/cat", LookupOrCreate(&rm, "foo", "bar", "dog")); + EXPECT_EQ("R/cat", Find(rm, "foo", "bar")); + + EXPECT_EQ("O/tiger", LookupOrCreate(&rm, "foo", "bar", "tiger")); + EXPECT_EQ("O/tiger", LookupOrCreate(&rm, "foo", "bar", "lion")); + TF_CHECK_OK(rm.Delete("foo", "bar")); + HasError(FindErr(rm, "foo", "bar"), "Not found: Resource foo/bar"); +} + +Status ComputePolicy(const string& attr_container, + const string& attr_shared_name, + bool use_node_name_as_default, string* result) { + ContainerInfo cinfo; + ResourceMgr rmgr; + NodeDef ndef; + ndef.set_name("foo"); + if (attr_container != "none") { + AddNodeAttr("container", attr_container, &ndef); + } + if (attr_shared_name != "none") { + AddNodeAttr("shared_name", attr_shared_name, &ndef); + } + TF_RETURN_IF_ERROR(cinfo.Init(&rmgr, ndef, use_node_name_as_default)); + *result = cinfo.DebugString(); + return Status::OK(); +} + +string Policy(const string& attr_container, const string& attr_shared_name, + bool use_node_name_as_default) { + string ret; + TF_CHECK_OK(ComputePolicy(attr_container, attr_shared_name, + use_node_name_as_default, &ret)); + return ret; +} + +TEST(ContainerInfo, Basic) { + // Correct cases. + EXPECT_EQ(Policy("", "", false), "[localhost,_0_foo,private]"); + EXPECT_EQ(Policy("", "", true), "[localhost,foo,public]"); + EXPECT_EQ(Policy("", "bar", false), "[localhost,bar,public]"); + EXPECT_EQ(Policy("", "bar", true), "[localhost,bar,public]"); + EXPECT_EQ(Policy("cat", "", false), "[cat,_1_foo,private]"); + EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]"); + EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]"); + EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]"); +} + +Status WrongPolicy(const string& attr_container, const string& attr_shared_name, + bool use_node_name_as_default) { + string dbg; + auto s = ComputePolicy(attr_container, attr_shared_name, + use_node_name_as_default, &dbg); + CHECK(!s.ok()); + return s; +} + +TEST(ContainerInfo, Error) { + // Missing attribute. + HasError(WrongPolicy("none", "", false), "No attr"); + HasError(WrongPolicy("", "none", false), "No attr"); + HasError(WrongPolicy("none", "none", false), "No attr"); + + // Invalid container. + HasError(WrongPolicy("12$%", "", false), "container contains invalid char"); + + // Invalid shared name. 
+ HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'"); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/step_stats.proto b/tensorflow/core/framework/step_stats.proto new file mode 100644 index 0000000000..78610350ec --- /dev/null +++ b/tensorflow/core/framework/step_stats.proto @@ -0,0 +1,58 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/tensor_description.proto"; + +// TODO(tucker): The next 4 message defs are very similar to +// the *LogEntry messages in profile.proto. They should be +// unified in one place. + +message AllocatorMemoryUsed { + string allocator_name = 1; + int64 total_bytes = 2; + int64 peak_bytes = 3; +} + +enum AllocationType { + AT_NOTUSED = 0; // tensor was not filled in + AT_ALLOCATED = 1; // tensor was allocated by the Op + AT_EXISTING = 2; // tensor was set to share the value of an existing tensor + AT_REF = 3; // tensor was set to be a reference to an existing tensor +} + +// Output sizes recorded for a single execution of a graph node. +message NodeOutput { + int32 slot = 1; + // Was the tensor allocated by this Op or a previous computation + AllocationType allocation_type = 2; + TensorDescription tensor_description = 3; +}; + +// Time/size stats recorded for a single execution of a graph node. +message NodeExecStats { + // TODO(tucker): Use some more compact form of node identity than + // the full string name. Either all processes should agree on a + // global id (cost_id?) for each node, or we should use a hash of + // the name. + string node_name = 1; + int64 all_start_micros = 2; + int64 op_start_rel_micros = 3; + int64 op_end_rel_micros = 4; + int64 all_end_rel_micros = 5; + repeated AllocatorMemoryUsed memory = 6; + repeated NodeOutput output = 7; + string timeline_label = 8; + int64 scheduled_micros = 9; + uint32 thread_id = 10; +}; + +message DeviceStepStats { + string device = 1; + repeated NodeExecStats node_stats = 2; +} + +message StepStats { + repeated DeviceStepStats dev_stats = 1; +}; diff --git a/tensorflow/core/framework/summary.proto b/tensorflow/core/framework/summary.proto new file mode 100644 index 0000000000..0e6e659f2f --- /dev/null +++ b/tensorflow/core/framework/summary.proto @@ -0,0 +1,67 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +// Serialization format for histogram module in +// core/lib/histogram/histogram.h +message HistogramProto { + double min = 1; + double max = 2; + double num = 3; + double sum = 4; + double sum_squares = 5; + + // Parallel arrays encoding the bucket boundaries and the bucket values. + // bucket(i) is the count for the bucket i. The range for + // a bucket is: + // i == 0: -DBL_MAX .. bucket_limit(0) + // i != 0: bucket_limit(i-1) .. bucket_limit(i) + repeated double bucket_limit = 6 [packed = true]; + repeated double bucket = 7 [packed = true]; +}; + +// A Summary is a set of named values to be displayed by the +// visualizer. +// +// Summaries are produced regularly during training, as controlled by +// the "summary_interval_secs" attribute of the training operation. +// Summaries are also produced at the end of an evaluation. +message Summary { + message Image { + // Dimensions of the image. + int32 height = 1; + int32 width = 2; + // Valid colorspace values are + // 1 - grayscale + // 2 - grayscale + alpha + // 3 - RGB + // 4 - RGBA + // 5 - DIGITAL_YUV + // 6 - BGRA + int32 colorspace = 3; + // Image data in encoded format. 
All image formats supported by + // image_codec::CoderUtil can be stored here. + bytes encoded_image_string = 4; + } + + message Value { + // Tag name for the data. Will be used as the title of the graph + // in the visualizer. + // + // Tag is usually "op_name:value_name", where "op_name" itself can have + // structure to indicate grouping. + string tag = 1; + + // Value associated with the tag. + oneof value { + float simple_value = 2; + bytes obsolete_old_style_histogram = 3; + Image image = 4; + HistogramProto histo = 5; + } + } + + // Set of values for the summary. + repeated Value value = 1; +} diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc new file mode 100644 index 0000000000..4a1b65db97 --- /dev/null +++ b/tensorflow/core/framework/tensor.cc @@ -0,0 +1,570 @@ +// Implementation notes: +// +// Tensor.cc uses a few templated classes and structs to facilitate +// implementation of the Tensor class. +// +// * Buffer: provides the implementation for a typed array T[n]. +// The array is allocated by the given allocator. It runs T's +// default constructors and destructors when T is not a simple type +// (e.g., string.), and skips them otherwise. +// +// * Helper: provides various routines given type T. The routines +// includes running the constructor and destructor of T[], encoding +// an decoding T[] into/from a Cord, etc. + +#include "tensorflow/core/public/tensor.h" + +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/tensor_coding.h" + +namespace tensorflow { +namespace { + +// Typed ref-counted buffer: T[n]. +template +class Buffer : public TensorBuffer { + public: + Buffer(Allocator* a, int64 n); + + void* data() const override { return data_; } + size_t size() const override { return sizeof(T) * elem_; } + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription(AllocationDescription* proto) const override { + int64 rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(alloc_->Name()); + if (alloc_->TracksAllocationSizes()) { + int64 ab = alloc_->AllocatedSize(data_); + proto->set_allocated_bytes(ab); + } + } + + private: + Allocator* alloc_; + T* data_; + int64 elem_; + + ~Buffer() override; + + TF_DISALLOW_COPY_AND_ASSIGN(Buffer); +}; + +// is_simple::value if T[] can be safely constructed and destructed +// without running T() and ~T(). We do not use std::is_trivial +// directly because std::complex is not trival but its array +// can be constructed and destructed without running its default ctor +// and dtor. +template +struct is_simple { + static const bool value = std::is_trivial::value || + std::is_same::value || + is_quantized::value; +}; + +template <> +struct is_simple { + static const bool value = true; +}; + +// A set of helper functions depending on T. +template +struct Helper { + // By default, we assume T is a simple type (float, int32, etc.) 
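+  // (Added note: for such simple types the underlying buffer is a plain,
+  //  trivially copyable T[n]; the Helper<string> specialization below is the
+  //  only case that needs real construction/destruction and per-element
+  //  encoding.)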
+ static_assert(is_simple::value, "T is not a simple type."); + typedef protobuf::RepeatedField RepeatedFieldType; + + // No constructor to run. + static void RunCtor(T* p, int n) {} + + // No destructor to run. + static void RunDtor(T* p, int n) {} + + // Encoder of simple type T to a string. We do a copy. + template + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + DCHECK_EQ(in->size(), sizeof(T) * n); + port::AssignRefCounted(StringPiece(in->base(), in->size()), in, + out); + } + + // Decoder of simple type T. Copy the bytes from "in" into the + // tensor buffer. + template + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + if (in.size() != sizeof(T) * n) { + LOG(ERROR) << "Input size was " << in.size() << " and expected " + << sizeof(T) * n; + return nullptr; + } + Buffer* buf = new Buffer(a, n); + port::CopyToArray(in, buf->template base()); + return buf; + } + + // Memory usage. + static int64 TotalBytes(TensorBuffer* in, int64 n) { + DCHECK_EQ(in->size(), sizeof(T) * n); + return in->size(); + } +}; + +// Helper specialization for string (the only non-simple type we +// support). +template <> +struct Helper { + // Proto message uses RepeatedFieldType to hold repeated T. + typedef protobuf::RepeatedPtrField RepeatedFieldType; + + // Runs string's default constructor for p[0], p[1], ..., p[n-1]. + static void RunCtor(string* p, int n) { + for (int i = 0; i < n; ++p, ++i) new (p) string(); + } + + // Runs T's default destructor for p[0], p[1], ..., p[n-1]. + static void RunDtor(string* p, int n) { + for (int i = 0; i < n; ++p, ++i) p->~string(); + } + + // Encodes "n" elements of type string stored in "in" into Cord + // "out", which is usually the TensorProto::tensor_content. + template + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + port::EncodeStringList(in->base(), n, out); + } + + // Decodes "n" elements of type string from "in" and constructs a + // buffer out of it. Returns nullptr if the decoding fails. "in" is + // usually the TensorProto::tensor_content. + template + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + Buffer* buf = new Buffer(a, n); + string* strings = buf->template base(); + if (port::DecodeStringList(in, strings, n)) { + return buf; + } else { + buf->Unref(); + return nullptr; + } + } + + // Returns the estimated memory usage of "n" elements of type T + // stored in buffer "in". + static int64 TotalBytes(TensorBuffer* in, int n) { + int64 tot = in->size(); + DCHECK_EQ(tot, sizeof(string) * n); + const string* p = in->base(); + for (int i = 0; i < n; ++i, ++p) tot += p->size(); + return tot; + } +}; + +template +struct ProtoHelper {}; + +// For a C++ type "T" (float, double, int32, etc.), the repeated field +// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float, +// int32, string, etc) in the TensorProto is used for serializing the +// tensor of type "T". 
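+//
+// For example, PROTO_TRAITS(uint8, int32, int) below stores uint8 tensors in
+// the repeated int32 field TensorProto::int_val.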
+#define PROTO_TRAITS(T, F, N) \ + template <> \ + struct ProtoHelper { \ + typedef Helper::RepeatedFieldType FieldType; \ + static FieldType::const_iterator Begin(const TensorProto& proto) { \ + return proto.N##_val().begin(); \ + } \ + static size_t NumElements(const TensorProto& proto) { \ + return proto.N##_val().size(); \ + } \ + static void Fill(const T* data, size_t n, TensorProto* proto) { \ + typename ProtoHelper::FieldType copy(data, data + n); \ + proto->mutable_##N##_val()->Swap(©); \ + } \ + }; +PROTO_TRAITS(float, float, float); +PROTO_TRAITS(double, double, double); +PROTO_TRAITS(int32, int32, int); +PROTO_TRAITS(uint8, int32, int); +PROTO_TRAITS(int16, int32, int); +PROTO_TRAITS(int8, int32, int); +PROTO_TRAITS(int64, int64, int64); +PROTO_TRAITS(bool, bool, bool); +PROTO_TRAITS(string, string, string); +PROTO_TRAITS(qint8, int32, int); +PROTO_TRAITS(quint8, int32, int); +#undef PROTO_TRAITS + +template <> +struct ProtoHelper { + typedef Helper::RepeatedFieldType FieldType; + static const complex64* Begin(const TensorProto& proto) { + return reinterpret_cast(proto.scomplex_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.scomplex_val().size() / 2; + } + static void Fill(const complex64* data, size_t n, TensorProto* proto) { + const float* p = reinterpret_cast(data); + FieldType copy(p, p + n * 2); + proto->mutable_scomplex_val()->Swap(©); + } +}; + +template <> +struct ProtoHelper { + typedef Helper::RepeatedFieldType FieldType; + static const qint32* Begin(const TensorProto& proto) { + return reinterpret_cast(proto.int_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.int_val().size(); + } + static void Fill(const qint32* data, size_t n, TensorProto* proto) { + const int32* p = reinterpret_cast(data); + FieldType copy(p, p + n); + proto->mutable_int_val()->Swap(©); + } +}; + +template <> +struct ProtoHelper { + typedef Helper::RepeatedFieldType FieldType; + static const bfloat16* Begin(const TensorProto& proto) { + return reinterpret_cast(proto.int_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.int_val().size(); + } + static void Fill(const bfloat16* data, size_t n, TensorProto* proto) { + proto->mutable_int_val()->Reserve(n); + for (size_t i = 0; i < n; ++i) { + proto->mutable_int_val()->AddAlreadyReserved(data[i].value); + } + } +}; + +template +Buffer::Buffer(Allocator* a, int64 n) + : alloc_(a), data_(a->Allocate(n)), elem_(n) { + if (data_) Helper::RunCtor(data_, elem_); +} + +template +Buffer::~Buffer() { + if (data_) { + Helper::RunDtor(data_, elem_); + alloc_->Deallocate(data_); + } +} + +// Allocates a T[n] buffer. Fills in the buffer with repeated values +// in "in". If "in" has less values than "n", fills the rest of T[n] +// with the last value. If "in" has no values, fills T[n] with the +// default value for T. +// +// This routine is using the typed fields (float_val, etc.) in the +// tenor proto as opposed to the untyped binary representation +// (tensor_content). This is used when we expect the TensorProto is +// used by a client program which may not know how to encode a tensor +// in the compact binary representation. 
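+//
+// For example (illustrative): a TensorProto with dtype DT_FLOAT, shape [4]
+// and float_val = {1.0} decodes here to the tensor {1.0, 1.0, 1.0, 1.0}.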
+template +TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) { + CHECK_GT(n, 0); + Buffer* buf = new Buffer(a, n); + T* data = buf->template base(); + const int64 in_n = ProtoHelper::NumElements(in); + auto begin = ProtoHelper::Begin(in); + if (n <= in_n) { + std::copy_n(begin, n, data); + } else if (in_n > 0) { + std::copy_n(begin, in_n, data); + const T& last = *(data + in_n - 1); + std::fill_n(data + in_n, n - in_n, last); + } else { + std::fill_n(data, n, T()); + } + return buf; +} + +// Copies T[n] stored in the buffer "in" into the repeated field in +// "out" corresponding to type T. +template +void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) { + const T* data = in.base(); + // NOTE: T may not the same as + // ProtoHelper::FieldType::value_type. E.g., T==int16, + // ProtoHelper::FieldType::value_type==int32. If performance is + // critical, we can specialize T=float and do memcpy directly. + ProtoHelper::Fill(data, n, out); +} + +void RefIfNonNull(core::RefCounted* buf) { + if (buf) buf->Ref(); +} + +void UnrefIfNonNull(core::RefCounted* buf) { + if (buf) buf->Unref(); +} + +} // end namespace + +Tensor::Tensor() : Tensor(DT_FLOAT) {} + +Tensor::Tensor(DataType type) : type_(type), shape_({0}), buf_(nullptr) {} + +Tensor::Tensor(const Tensor& other) + : type_(other.dtype()), shape_(other.shape()), buf_(other.buf_) { + RefIfNonNull(buf_); +} + +Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf) + : type_(type), shape_(shape), buf_(buf) { + RefIfNonNull(buf); +} + +bool Tensor::IsInitialized() const { + return buf_ != nullptr && buf_->data() != nullptr; +} + +Tensor::~Tensor() { UnrefIfNonNull(buf_); } + +void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) { + CHECK_EQ(shape.num_elements(), other.NumElements()); + type_ = other.dtype(); + shape_ = shape; + if (buf_ != other.buf_) { + UnrefIfNonNull(buf_); + buf_ = other.buf_; + RefIfNonNull(buf_); + } +} + +// The macro CASES() expands to a switch statement conditioned on +// TYPE_ENUM. Each case expands the STMTS after a typedef for T. +#define SINGLE_ARG(...) __VA_ARGS__ +#define CASE(TYPE, STMTS) \ + case DataTypeToEnum::value: { \ + typedef TYPE T; \ + STMTS; \ + break; \ + } +#define CASES(TYPE_ENUM, STMTS) \ + switch (TYPE_ENUM) { \ + CASE(float, SINGLE_ARG(STMTS)) \ + CASE(double, SINGLE_ARG(STMTS)) \ + CASE(int32, SINGLE_ARG(STMTS)) \ + CASE(uint8, SINGLE_ARG(STMTS)) \ + CASE(int16, SINGLE_ARG(STMTS)) \ + CASE(int8, SINGLE_ARG(STMTS)) \ + CASE(string, SINGLE_ARG(STMTS)) \ + CASE(complex64, SINGLE_ARG(STMTS)) \ + CASE(int64, SINGLE_ARG(STMTS)) \ + CASE(bool, SINGLE_ARG(STMTS)) \ + CASE(qint32, SINGLE_ARG(STMTS)) \ + CASE(quint8, SINGLE_ARG(STMTS)) \ + CASE(qint8, SINGLE_ARG(STMTS)) \ + CASE(bfloat16, SINGLE_ARG(STMTS)) \ + case DT_INVALID: \ + LOG(FATAL) << "Type not set"; \ + break; \ + default: \ + LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \ + break; \ + } + +Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape) + : type_(type), shape_(shape), buf_(nullptr) { + CHECK_NOTNULL(a); + if (shape_.num_elements() > 0) { + CASES(type, buf_ = new Buffer(a, shape.num_elements())); + } +} + +Tensor::Tensor(DataType type, const TensorShape& shape) + : Tensor(cpu_allocator(), type, shape) {} + +template +class SubBuffer : public TensorBuffer { + public: + // This buffer is an alias to buf[delta, delta + n). 
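+  // (Added note: Tensor::Slice() below relies on this, aliasing rows
+  //  [start, limit) of a rank >= 1 tensor with delta = start * elements
+  //  per row over the original root buffer, which stays Ref'd for the
+  //  alias's lifetime.)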
+ SubBuffer(TensorBuffer* buf, int64 delta, int64 n) + : root_(buf->root_buffer()), data_(buf->base() + delta), elem_(n) { + // Sanity check. The caller should ensure the sub buffer is valid. + CHECK_LE(root_->base(), this->base()); + T* root_limit = root_->base() + root_->size() / sizeof(T); + CHECK_LE(this->base(), root_limit); + CHECK_LE(this->base() + n, root_limit); + // Hold a ref of the underlying root buffer. + // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer. + root_->Ref(); + } + + void* data() const override { return data_; } + size_t size() const override { return sizeof(T) * elem_; } + TensorBuffer* root_buffer() override { return root_; } + void FillAllocationDescription(AllocationDescription* proto) const override { + root_->FillAllocationDescription(proto); + } + + private: + TensorBuffer* root_; + T* data_; + int64 elem_; + + ~SubBuffer() override { root_->Unref(); } + + TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer); +}; + +Tensor Tensor::Slice(int64 start, int64 limit) const { + CHECK_GE(dims(), 1); + CHECK_LE(0, start); + CHECK_LE(start, limit); + int64 dim0_size = shape_.dim_size(0); + CHECK_LE(limit, dim0_size); + if ((start == 0) && (limit == dim0_size)) { + return *this; + } + Tensor ret; + ret.type_ = type_; + ret.shape_ = shape_; + ret.buf_ = nullptr; + if (dim0_size > 0) { + const int64 elems_per_dim0 = NumElements() / dim0_size; + const int64 delta = start * elems_per_dim0; + dim0_size = limit - start; + ret.shape_.set_dim(0, dim0_size); + const int64 num_elems = dim0_size * elems_per_dim0; + if (buf_) { + CASES(type_, ret.buf_ = new SubBuffer(buf_, delta, num_elems)); + } + } + return ret; +} + +bool Tensor::FromProto(const TensorProto& proto) { + return FromProto(cpu_allocator(), proto); +} + +bool Tensor::FromProto(Allocator* a, const TensorProto& proto) { + CHECK_NOTNULL(a); + TensorBuffer* p = nullptr; + if (!TensorShape::IsValid(proto.tensor_shape())) return false; + if (proto.dtype() == DT_INVALID) return false; + TensorShape shape(proto.tensor_shape()); + const int64 N = shape.num_elements(); + if (N > 0 && proto.dtype()) { + if (!proto.tensor_content().empty()) { + const auto& content = proto.tensor_content(); + CASES(proto.dtype(), p = Helper::Decode(a, content, N)); + } else { + CASES(proto.dtype(), p = FromProtoField(a, proto, N)); + } + if (p == nullptr) return false; + } + type_ = proto.dtype(); + shape_ = shape; + UnrefIfNonNull(buf_); + buf_ = p; + return true; +} + +void Tensor::AsProtoField(TensorProto* proto) const { + proto->Clear(); + proto->set_dtype(dtype()); + shape_.AsProto(proto->mutable_tensor_shape()); + if (buf_) { + CASES(dtype(), ToProtoField(*buf_, shape_.num_elements(), proto)); + } +} + +void Tensor::AsProtoTensorContent(TensorProto* proto) const { + proto->Clear(); + proto->set_dtype(type_); + shape_.AsProto(proto->mutable_tensor_shape()); + if (buf_) { + CASES(dtype(), Helper::Encode(buf_, shape_.num_elements(), + proto->mutable_tensor_content())); + } +} + +size_t Tensor::TotalBytes() const { + if (shape_.num_elements() == 0) return 0; + CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements(); + CASES(dtype(), return Helper::TotalBytes(buf_, shape_.num_elements())); + return 0; // Makes compiler happy. +} + +bool Tensor::CanUseDMA() const { + CASES(dtype(), return is_simple::value); + return false; // Makes compiler happy. 
+} + +#undef CASES +#undef CASE + +string Tensor::SummarizeValue(int64 max_entries) const { + string ret; + for (int64 i = 0; i < std::min(max_entries, NumElements()); ++i) { + if (i > 0) strings::StrAppend(&ret, " "); + switch (dtype()) { + case DT_STRING: + strings::StrAppend(&ret, str_util::CEscape(flat()(i))); + break; + case DT_BOOL: + strings::StrAppend(&ret, flat()(i) ? "True" : "False"); + break; + +#define CASE(DT_ENUM) \ + case DT_ENUM: \ + strings::StrAppend(&ret, flat::Type>()(i)); \ + break + + CASE(DT_FLOAT); + CASE(DT_DOUBLE); + CASE(DT_INT32); + CASE(DT_UINT8); + CASE(DT_INT16); + CASE(DT_INT8); + CASE(DT_INT64); + +#undef CASE + default: + // TODO(zhifengc, josh11b): Pretty-print other types (bool, + // complex64, quantized, bfloat16). + strings::StrAppend(&ret, " ?"); + } + } + if (max_entries < NumElements()) strings::StrAppend(&ret, "..."); + + return ret; +} + +StringPiece Tensor::tensor_data() const { + if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors + return StringPiece(static_cast(buf_->data()), TotalBytes()); +} + +string Tensor::DebugString() const { + return strings::StrCat("Tensor"); +} + +void Tensor::FillDescription(TensorDescription* description) const { + description->set_dtype(dtype()); + shape().AsProto(description->mutable_shape()); + buf_->FillAllocationDescription( + description->mutable_allocation_description()); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor.proto b/tensorflow/core/framework/tensor.proto new file mode 100644 index 0000000000..b42694afde --- /dev/null +++ b/tensorflow/core/framework/tensor.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. TODO(mdevin): sort out the 0-rank issues. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized content from TensorBase::Serialize() This representation can be + // used for all tensor types. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. 
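+  // E.g., the complex tensor {1+2i, 3+4i} is stored as
+  // scomplex_val: [1, 2, 3, 4].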
+ repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; +}; diff --git a/tensorflow/core/framework/tensor_description.proto b/tensorflow/core/framework/tensor_description.proto new file mode 100644 index 0000000000..1fff3ee155 --- /dev/null +++ b/tensorflow/core/framework/tensor_description.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/allocation_description.proto"; + +message TensorDescription { + // Data type of tensor elements + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto shape = 2; + + // Information about the size and allocator used for the data + AllocationDescription allocation_description = 4; +}; diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc new file mode 100644 index 0000000000..3db2ffaaca --- /dev/null +++ b/tensorflow/core/framework/tensor_shape.cc @@ -0,0 +1,138 @@ +#include "tensorflow/core/public/tensor_shape.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// An upper limit of the total number of elements in a tensor. +static const int64 kMaxElements = (1LL << 40); + +bool TensorShape::IsValid(const TensorShapeProto& proto) { + int64 num_elements = 1; + for (const auto& d : proto.dim()) { + if (d.size() < 0) return false; + num_elements *= d.size(); + if (num_elements > kMaxElements) return false; + } + return true; +} + +TensorShape::TensorShape(const TensorShapeProto& proto) { + dim_sizes_.reserve(proto.dim_size()); + num_elements_ = 1; + for (const auto& d : proto.dim()) { + AddDim(d.size()); + } +} + +TensorShape::TensorShape(gtl::ArraySlice dim_sizes) { + dim_sizes_.reserve(dim_sizes.size()); + num_elements_ = 1; + for (auto s : dim_sizes) { + AddDim(s); + } +} + +TensorShape::TensorShape() : num_elements_(1) {} + +void TensorShape::Clear() { + dim_sizes_.clear(); + num_elements_ = 1; +} + +void TensorShape::AddDim(int64 size) { + CHECK_GE(size, 0); + dim_sizes_.push_back(size); + num_elements_ *= size; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); +} + +void TensorShape::AppendShape(const TensorShape& shape) { + for (auto d : shape) AddDim(d.size); +} + +void TensorShape::InsertDim(int d, int64 size) { + CHECK_GE(d, 0); + CHECK_LE(d, dims()); + CHECK_GE(size, 0); + dim_sizes_.insert(dim_sizes_.begin() + d, size); + num_elements_ *= size; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); +} + +void TensorShape::set_dim(int d, int64 size) { + CHECK_GE(d, 0); + CHECK_LT(d, dims()); + CHECK_GE(size, 0); + + // Update the number of elements. num_elements_ is int64. + dim_sizes_[d] = size; + recompute_dims(); +} + +void TensorShape::RemoveDim(int d) { + CHECK_GE(d, 0); + CHECK_LT(d, dims()); + + // Update the number of elements and remove the dimension from the + // sizes. 
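A short sketch of how num_elements() tracks the mutators implemented above (illustrative only; it uses the public tensor_shape.h header this file implements):

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

void ShapeBookkeeping() {
  TensorShape s({10, 5});    // num_elements() == 50
  s.AddDim(20);              // {10, 5, 20}    -> 1000
  s.InsertDim(0, 2);         // {2, 10, 5, 20} -> 2000
  s.set_dim(1, 3);           // {2, 3, 5, 20}  -> 600, via recompute_dims()
  CHECK_EQ(s.dims(), 4);
  CHECK_EQ(s.num_elements(), 600);
}

}  // namespace tensorflow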
+ dim_sizes_.erase(dim_sizes_.begin() + d); + recompute_dims(); +} + +void TensorShape::recompute_dims() { + num_elements_ = 1; + for (auto s : dim_sizes_) { + num_elements_ *= s; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); + } +} + +bool TensorShape::IsSameSize(const TensorShape& b) const { + if (b.dims() != dims()) return false; + for (int d = 0; d < dims(); d++) { + if (dim_size(d) != b.dim_size(d)) return false; + } + return true; +} + +void TensorShape::AsProto(TensorShapeProto* proto) const { + proto->Clear(); + for (size_t d = 0; d < dim_sizes_.size(); ++d) { + auto* dim = proto->add_dim(); + dim->set_size(dim_sizes_[d]); + } +} + +TensorShapeIter TensorShape::begin() const { return TensorShapeIter(this, 0); } + +TensorShapeIter TensorShape::end() const { + return TensorShapeIter(this, dims()); +} + +string TensorShape::DebugString() const { + TensorShapeProto proto; + AsProto(&proto); + return proto.ShortDebugString(); +} + +string TensorShape::ShortDebugString() const { + return strings::StrCat( + "[", str_util::Join(gtl::ArraySlice(dim_sizes_), ","), "]"); +} + +bool TensorShapeUtils::StartsWith(const TensorShape& shape, + const TensorShape& prefix) { + if (shape.dims() < prefix.dims()) return false; + for (int i = 0; i < prefix.dims(); i++) { + if (shape.dim_size(i) != prefix.dim_size(i)) return false; + } + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_shape.proto b/tensorflow/core/framework/tensor_shape.proto new file mode 100644 index 0000000000..8fe7cce13d --- /dev/null +++ b/tensorflow/core/framework/tensor_shape.proto @@ -0,0 +1,29 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +// option cc_enable_arenas = true; + +package tensorflow; + +// Dimensions of a tensor and the type of data it contains. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} for a 30 x + // 40 2D tensor. The names are optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + repeated Dim dim = 2; +}; diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc new file mode 100644 index 0000000000..adac1a4787 --- /dev/null +++ b/tensorflow/core/framework/tensor_shape_test.cc @@ -0,0 +1,75 @@ +#include "tensorflow/core/public/tensor_shape.h" + +#include + +namespace tensorflow { +namespace { + +TEST(TensorShapeTest, Default) { + // The default TensorShape constructor constructs a shape of 0-dim + // and 1-element. 
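To make the layout comment in tensor_shape.proto concrete, this is the row-major index arithmetic it implies (a standalone sketch, not part of the patch):

#include <cstdint>
#include <vector>

// dims = {d0, d1, ..., dn-1} with d0 outermost; the flat offset of element
// (i0, i1, ..., in-1) is ((i0 * d1 + i1) * d2 + i2) * ... + in-1.
int64_t RowMajorOffset(const std::vector<int64_t>& dims,
                       const std::vector<int64_t>& indices) {
  int64_t offset = 0;
  for (size_t d = 0; d < dims.size(); ++d) {
    offset = offset * dims[d] + indices[d];
  }
  return offset;
}
// For the 30 x 40 example in the comment, element (2, 5) lands at
// 2 * 40 + 5 = 85.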
+ TensorShape s; + EXPECT_EQ(s.dims(), 0); + EXPECT_EQ(s.num_elements(), 1); +} + +TEST(TensorShapeTest, set_dim) { + TensorShape s({10, 5}); + + s.set_dim(0, 20); + ASSERT_EQ(2, s.dims()); + EXPECT_EQ(20, s.dim_size(0)); + EXPECT_EQ(100, s.num_elements()); + + s.set_dim(1, 2); + ASSERT_EQ(2, s.dims()); + EXPECT_EQ(2, s.dim_size(1)); + EXPECT_EQ(40, s.num_elements()); +} + +TEST(TensorShapeTest, RemoveDim) { + TensorShape s({10, 5}); + s.RemoveDim(0); + EXPECT_EQ(5, s.num_elements()); + ASSERT_EQ(1, s.dims()); +} + +TEST(TensorShapeTest, RemoveAndAddDim) { + TensorShape s({10, 5, 20}); + s.RemoveDim(1); + s.AddDim(100); + + EXPECT_EQ(20000, s.num_elements()); + ASSERT_EQ(3, s.dims()); +} + +TEST(TensorShapeTest, InvalidShapeProto) { + TensorShapeProto proto; + EXPECT_TRUE(TensorShape::IsValid(proto)); + + proto.add_dim()->set_size(357); + proto.add_dim()->set_size(982); + EXPECT_TRUE(TensorShape::IsValid(proto)); + + proto.Clear(); + proto.add_dim()->set_size(-357); + proto.add_dim()->set_size(-982); + EXPECT_FALSE(TensorShape::IsValid(proto)); + + proto.Clear(); + proto.add_dim()->set_size(1LL << 20); + proto.add_dim()->set_size((1LL << 20) + 1); + EXPECT_FALSE(TensorShape::IsValid(proto)); +} + +TEST(TensorShapeTest, SetDimForEmptyTensor) { + TensorShape s({10, 5, 20}); + EXPECT_EQ(1000, s.num_elements()); + s.set_dim(1, 0); + EXPECT_EQ(0, s.num_elements()); + s.set_dim(1, 7); + EXPECT_EQ(1400, s.num_elements()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_slice.cc b/tensorflow/core/framework/tensor_slice.cc new file mode 100644 index 0000000000..473d9463ee --- /dev/null +++ b/tensorflow/core/framework/tensor_slice.cc @@ -0,0 +1,226 @@ +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); } + +TensorSlice::TensorSlice(const TensorSliceProto& proto) { + starts_.reserve(proto.extent_size()); + lengths_.reserve(proto.extent_size()); + for (const auto& e : proto.extent()) { + starts_.push_back(e.start()); + lengths_.push_back(GetExtentLength(e)); + } +} + +TensorSlice::TensorSlice(std::initializer_list> extents) { + starts_.reserve(extents.size()); + lengths_.reserve(extents.size()); + for (const auto& e : extents) { + starts_.push_back(e.first); + lengths_.push_back(e.second); + } +} + +Status TensorSlice::Parse(const string& str, TensorSlice* slice) { + std::vector items = str_util::Split(str, ':', str_util::SkipEmpty()); + slice->starts_.reserve(items.size()); + slice->lengths_.reserve(items.size()); + for (const string& x : items) { + int s, l; + if (x == "-") { + // "everything" + s = 0; + l = kFullExtent; + } else { + char junk; + if (sscanf(x.c_str(), "%d,%d%c", &s, &l, &junk) != 2) { + return errors::InvalidArgument( + "Expected a pair of numbers or '-' " + "but got '", + x, "': string = ", str); + } + if (s < 0 || l <= 0) { + return errors::InvalidArgument( + "Expected non-negative start and " + "positive length but got start = ", + s, ", length = ", l, ": string = ", str); + } + } + slice->starts_.push_back(s); + slice->lengths_.push_back(l); + } + + return Status::OK(); +} + +void TensorSlice::Clear() { + starts_.clear(); + lengths_.clear(); +} + +void TensorSlice::SetFullSlice(int dim) { + Clear(); + starts_.reserve(dim); + lengths_.reserve(dim); + for (int d 
= 0; d < dim; ++d) { + starts_.push_back(0); + lengths_.push_back(kFullExtent); + } +} + +void TensorSlice::Extend(int dim) { + int old_dim = dims(); + DCHECK_LE(old_dim, dim); + starts_.resize(dim); + lengths_.resize(dim); + for (int d = old_dim; d < dim; ++d) { + starts_[d] = 0; + lengths_[d] = kFullExtent; + } +} + +void TensorSlice::AsProto(TensorSliceProto* proto) const { + for (int d = 0; d < dims(); ++d) { + TensorSliceProto::Extent* e = proto->add_extent(); + // We only need to record the explicit slice for non-full slices + if (!IsFullAt(d)) { + e->set_start(starts_[d]); + e->set_length(lengths_[d]); + } + } +} + +string TensorSlice::DebugString() const { + string buffer; + bool first = true; + for (int d = 0; d < dims(); ++d) { + if (!first) { + buffer.append(":"); + } + string s; + if (IsFullAt(d)) { + buffer.append("-"); + } else { + strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]); + } + first = false; + } + return buffer; +} + +bool TensorSlice::Intersect(const TensorSlice& other, + TensorSlice* result) const { + // First, if two slices have different ranks, they obviously don't overlap + // -- in fact they are not compatible. + if (dims() != other.dims()) { + return false; + } + + // Setting the result to the right dimension + if (result) { + result->SetFullSlice(dims()); + } + // The two slices overlap if they overlap in all dimensions. + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + if (result) { + result->set_start(d, other.start(d)); + result->set_length(d, other.length(d)); + } + } else if (other.IsFullAt(d)) { + if (result) { + result->set_start(d, start(d)); + result->set_length(d, length(d)); + } + } else { + // If we have an intersection here, it should have a start that is the + // max of the two starts and an end that is the min of the two ends. + int s = std::max(start(d), other.start(d)); + int l = std::min(end(d), other.end(d)) - s; + if (l > 0) { + // We have a real intersection + if (result) { + result->set_start(d, s); + result->set_length(d, l); + } + } else { + // We don't have an intersection for this dimension -- thus we don't + // have any intersection at all. + if (result) { + result->Clear(); + } + return false; + } + } + } + // If we are here, we know there is overlap in every dimension. + return true; +} + +void TensorSlice::ComputeRelative(const TensorSlice& sub, + TensorSlice* relative) const { + DCHECK_EQ(dims(), sub.dims()); + relative->SetFullSlice(dims()); + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + relative->set_start(d, sub.start(d)); + relative->set_length(d, sub.length(d)); + } else { + // Otherwise the relative start is the difference between the start of + // sub and the start of base + relative->set_start(d, sub.start(d) - start(d)); + relative->set_length(d, sub.length(d)); + } + } +} + +// static +bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) { + return extent.has_length_case() == TensorSliceProto::Extent::kLength; +} + +// static +int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) { + if (!HasExtentLength(extent)) return -1; + return extent.length(); +} + +Status TensorSlice::SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const { + result_shape->Clear(); + // Mismatching ranks: we can't apply the slice at all. 
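The per-dimension rule used by Intersect() above reduces to the following arithmetic (standalone sketch for illustration):

#include <algorithm>

// Two extents (start, length) overlap iff s = max of the starts and
// l = min of the ends - s is positive; (s, l) is then the intersection.
bool IntersectExtent(int start_a, int len_a, int start_b, int len_b,
                     int* start_out, int* len_out) {
  const int s = std::max(start_a, start_b);
  const int l = std::min(start_a + len_a, start_b + len_b) - s;
  if (l <= 0) return false;
  *start_out = s;
  *len_out = l;
  return true;
}
// e.g. "3,7" and "5,10" give s = 5, l = min(10, 15) - 5 = 5, i.e. "5,5";
// "3,1" and "4,5" give l = min(4, 9) - 4 = 0, i.e. no overlap.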
+ if (shape.dims() != dims()) { + return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(), + ", slice = ", DebugString()); + } + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + result_shape->AddDim(shape.dim_size(d)); + } else { + // Check if the extent applies to the dimension + if (end(d) <= shape.dim_size(d)) { + // Yes: the end is within the range of the dim -- we adjust the result + // shape so that its size along this dimension is the length of the + // slice. + result_shape->AddDim(length(d)); + } else { + // The extent doesn't apply to the dimension + result_shape->Clear(); + return errors::Internal("Extent in dimension ", d, + " out of bounds: shape = ", shape.DebugString(), + ", slice = ", DebugString()); + } + } + } + // If we are here, we have successfully applied the shape. + return Status::OK(); +} + +const int TensorSlice::kFullExtent = -1; + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h new file mode 100644 index 0000000000..8e2f108c3f --- /dev/null +++ b/tensorflow/core/framework/tensor_slice.h @@ -0,0 +1,189 @@ +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_ + +#include +#include "tensorflow/core/framework/tensor_slice.pb.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// A tensor slice represents a slice of a given tensor. It is represented by a +// list of (start, length) pairs, where the size of the list is the rank of the +// tensor. + +class TensorSlice { + public: + // Construct a tensor slice: you have a number of ways: + // -- creating an empty slice + // -- from just a dimension (in this case it will create a full slice) + // -- from an array of pairs of integers. + // -- from a TensorSliceProto protocol buffer + // -- from a string format of "start,lenth:start,length..." where each + // "start,length" pair represents the slice on one dimension. We allow a + // special "-" that means "everything for this dimension". One such example + // is: 0,10:-:14,1:-:- + TensorSlice() {} + explicit TensorSlice(int dim); + explicit TensorSlice(const TensorSliceProto& proto); + explicit TensorSlice(std::initializer_list> extents); + + static Status Parse(const string& str, TensorSlice* output); + static TensorSlice ParseOrDie(const string& str) { + TensorSlice ret; + Status s = Parse(str, &ret); + if (!s.ok()) { + LOG(FATAL) << "Could not parse TensorSlice"; + } + return ret; + } + + void Clear(); + + // Accessors + int dims() const { return starts_.size(); } + + int start(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return starts_[d]; + } + + int length(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return lengths_[d]; + } + + int end(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return start(d) + length(d); + } + + void set_start(int d, int x) { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + DCHECK_GE(x, 0); + starts_[d] = x; + } + + void set_length(int d, int x) { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + lengths_[d] = x; + } + + // If we have a full slice along dimension "d". 
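For a non-full base dimension, ComputeRelative() (implemented earlier in tensor_slice.cc) is just a shift of the start; a one-dimension sketch of that arithmetic:

struct Extent { int start, length; };

// Assumes 'sub' is contained in 'base' and 'base' is not a full slice:
// the relative start is sub.start - base.start, the length is unchanged.
Extent RelativeExtent(Extent base, Extent sub) {
  return Extent{sub.start - base.start, sub.length};
}
// With base "1,2:3,4:-:5,1" and sub "1,1:4,2:3,3:5,1" this yields
// "0,1:1,2:3,3:0,1"; a full "-" base dimension passes sub through unchanged.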
+ bool IsFullAt(int d) const { return lengths_[d] < 0; } + + // Set the slice to be a full slice of "dim" dimensions + void SetFullSlice(int dim); + + // Extend a slice to "dim" dimensions: all the added dimensions are full. + // Requires: dim >= dims(). + void Extend(int dim); + + // Conversion of a TensorSlice to other formats + void AsProto(TensorSliceProto* proto) const; + string DebugString() const; + + // Fill *indices and *sizes from *this (so that we can use the slice() + // function in eigen tensor). We need a tensor shape in case some of the + // slices are full slices. + // We allow NDIMS to be greater than dims(), in which case we will pad the + // higher dimensions with trivial dimensions. + template + void FillIndicesAndSizes(const TensorShape& shape, + Eigen::DSizes* indices, + Eigen::DSizes* sizes) const; + + // Interaction with other TensorSlices. + + // Compute the intersection with another slice and if "result" is not + // nullptr, store the results in *result; returns true is there is any real + // intersection. + bool Intersect(const TensorSlice& other, TensorSlice* result) const; + // A short hand. + bool Overlaps(const TensorSlice& other) const { + return Intersect(other, nullptr); + } + + // Interaction with TensorShape. + + // Slices a shape and stores the result into *result_shape. + // Requires that the shape and *this have the same rank. + // For example, given a tensor shape of {3, 4, 5}, and a slice of + // 1,2:-:0,2, the result shape is {2, 4, 2}. + Status SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const; + + // Given slice "sub" where "sub" is fully contained in *this, + // (meaning that the intersection of "sub" and *this equals "sub"), computes + // the "relative" slice of "sub" with respect to *this. + // + // In other words, if we use A>S to denote slicing a shape S with a slice A, + // then the function is computing a slice X such that: + // X > (this > S) = sub > S + // for any shape S. + // + // In general, along every dimension, the start of the relative slice is the + // start of the "sub" slice minus the start of *this; the length of the + // relative slice is the length of the "sub" slice. + // + // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and + // "sub" is 1,1:2:2,1,2, then the related slice is 1,1:2,2:0,2. + // + // The caller needs to make sure that "sub" is indeed a sub-slice of *this; + // otherwise the result is undefined. + void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const; + + // Returns true if the length field was specified in an Extent. + static bool HasExtentLength(const TensorSliceProto::Extent& extent); + + // Returns the value of the length field in an Extent, or -1 if it + // is not present. + static int64 GetExtentLength(const TensorSliceProto::Extent& extent); + + private: + // a length value of kFullExtent (-1) means we have a full slice at this + // dimension. It's defined in tensor_slice.cc. + static const int kFullExtent; + + // TODO(yangke): switch to Eigen once it supports variable size arrays. 
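A usage sketch for the two shape-facing calls documented above (illustrative; it relies only on headers tensor_slice.h already includes):

#include "tensorflow/core/framework/tensor_slice.h"

namespace tensorflow {

void SliceShapeExample() {
  TensorSlice slice = TensorSlice::ParseOrDie("1,2:-:0,2");
  TensorShape shape({3, 4, 5});
  TensorShape sliced;
  Status s = slice.SliceTensorShape(shape, &sliced);
  CHECK(s.ok());
  // sliced is {2, 4, 2}: the full "-" dimension keeps its size 4, the
  // other dimensions take the slice lengths 2 and 2.
  CHECK(sliced.IsSameSize(TensorShape({2, 4, 2})));
}

}  // namespace tensorflow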
+ // A value of + gtl::InlinedVector starts_; + gtl::InlinedVector lengths_; +}; + +template +void TensorSlice::FillIndicesAndSizes( + const TensorShape& shape, Eigen::DSizes* indices, + Eigen::DSizes* sizes) const { + CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape " + << "slices: shape = " << shape.DebugString() + << ", slice = " << DebugString(); + CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from " + << "a slice of dimension " << dims(); + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + (*indices)[d] = 0; + (*sizes)[d] = shape.dim_size(d); + } else { + (*indices)[d] = starts_[d]; + (*sizes)[d] = lengths_[d]; + } + } + for (int d = dims(); d < NDIMS; ++d) { + (*indices)[d] = 0; + (*sizes)[d] = 1; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_ diff --git a/tensorflow/core/framework/tensor_slice.proto b/tensorflow/core/framework/tensor_slice.proto new file mode 100644 index 0000000000..ca676bc766 --- /dev/null +++ b/tensorflow/core/framework/tensor_slice.proto @@ -0,0 +1,34 @@ +// Protocol buffer representing slices of a tensor + +syntax = "proto3"; +// option cc_enable_arenas = true; + +package tensorflow; + +// Can only be interpreted if you know the corresponding TensorShape. +message TensorSliceProto { + // Extent of the slice in one dimension. + message Extent { + // Either both or no attributes must be set. When no attribute is set + // means: All data in that dimension. + + // Start index of the slice, starting at 0. + int64 start = 1; + + // Length of the slice: if the length is missing or -1 we will + // interpret this as "everything in this dimension". We use + // "oneof" to preserve information about whether the length is + // present without changing the serialization format from the + // prior proto2 version of this proto. + oneof has_length { + int64 length = 2; + } + }; + + // Extent of the slice in all tensor dimensions. + // + // Must have one entry for each of the dimension of the tensor that this + // slice belongs to. The order of sizes is the same as the order of + // dimensions in the TensorShape. + repeated Extent extent = 1; +}; diff --git a/tensorflow/core/framework/tensor_slice_test.cc b/tensorflow/core/framework/tensor_slice_test.cc new file mode 100644 index 0000000000..5f718a56b6 --- /dev/null +++ b/tensorflow/core/framework/tensor_slice_test.cc @@ -0,0 +1,246 @@ +#include "tensorflow/core/framework/tensor_slice.h" + +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/logging.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +// Basic tests +TEST(TensorSliceTest, Basic) { + { + // Repeatedly setting FullSlice should work. + TensorSlice s(3); + EXPECT_EQ("-:-:-", s.DebugString()); + + s.SetFullSlice(4); + EXPECT_EQ("-:-:-:-", s.DebugString()); + } +} + +// Testing for serialization and parsing for the string format of slices. +TEST(TensorSliceTest, Serialization) { + // Serialization + { + TensorSlice s({{0, -1}, {0, 10}, {14, 1}, {0, -1}}); + EXPECT_EQ("-:0,10:14,1:-", s.DebugString()); + } + + { + TensorSliceProto proto; + // Define ptxt outside ASSERT_TRUE call to work around bug in some + // versions of gcc that breaks when you use raw string literals + // inside macro expansions. 
+ // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 + const char* ptxt = R"PROTO( + extent { } + extent { start: 0 length: 10 } + extent { start: 14 length: 1 } + extent { } + )PROTO"; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); + TensorSlice s(proto); + EXPECT_EQ("-:0,10:14,1:-", s.DebugString()); + } + + // Parsing + { + TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5"); + TensorSliceProto proto; + s.AsProto(&proto); + EXPECT_EQ( + "extent { } " + "extent { } " + "extent { start: 1 length: 3 } " + "extent { start: 4 length: 5 }", + proto.ShortDebugString()); + } + + // Failed parsing + { + TensorSlice slice; + Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice); + EXPECT_EQ( + "Invalid argument: " + "Expected a pair of numbers or '-' but got '4': " + "string = -:-:1,3:4:5", + s.ToString()); + } + { + TensorSlice slice; + Status s = TensorSlice::Parse("-:-1,3", &slice); + EXPECT_EQ( + "Invalid argument: " + "Expected non-negative start and positive length but got " + "start = -1, length = 3: string = -:-1,3", + s.ToString()); + } +} + +// Testing the slice intersection +TEST(TensorSliceTest, Intersection) { + // "EVERYTHING" intersects with everything + { + TensorSlice a = TensorSlice::ParseOrDie("-:-"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("1,2:3,4", c.DebugString()); + } + + { + TensorSlice a = TensorSlice::ParseOrDie("-:-"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4"); + TensorSlice c; + EXPECT_TRUE(b.Intersect(a, &c)); + EXPECT_EQ("1,2:3,4", c.DebugString()); + } + + // Overlap at all dimensions + { + TensorSlice a = TensorSlice::ParseOrDie("1,5:2,6:3,7:5,10"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4:9,10:12,1"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("1,2:3,4:9,1:12,1", c.DebugString()); + } + + // A mixture of everything and non-trivial slices + { + TensorSlice a = TensorSlice::ParseOrDie("-:1,1"); + TensorSlice b = TensorSlice::ParseOrDie("-:0,2"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("-:1,1", c.DebugString()); + } + + // No overlap on dimension 3: "3,1" and "4,5" don't intersect + { + TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:5,6"); + TensorSlice b = TensorSlice::ParseOrDie("1,3:4,5:1,6"); + TensorSlice c; + EXPECT_FALSE(a.Intersect(b, &c)); + EXPECT_EQ("", c.DebugString()); + } + // No intersection when there are different dimensions + { + TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:-"); + TensorSlice b = TensorSlice::ParseOrDie("-:-"); + TensorSlice c; + EXPECT_FALSE(a.Intersect(b, &c)); + EXPECT_EQ("", c.DebugString()); + } +} + +// Testing applying a slice to a tensor shape +TEST(TensorSliceTest, SliceTensorShape) { + // A proper application + { + TensorSlice a = TensorSlice::ParseOrDie("1,1:-:4,1:2,6"); + TensorShape x({2, 4, 5, 8}); + TensorShape y; + EXPECT_OK(a.SliceTensorShape(x, &y)); + EXPECT_EQ( + "dim { size: 1 } " + "dim { size: 4 } " + "dim { size: 1 } " + "dim { size: 6 }", + y.DebugString()); + } + + // An invalid application -- dimension 2 is out of bound + { + TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-"); + TensorShape x({2, 4, 5, 8}); + TensorShape y; + EXPECT_EQ( + "Internal: " + "Extent in dimension 1 out of bounds: " + "shape = dim { size: 2 } " + "dim { size: 4 } " + "dim { size: 5 } " + "dim { size: 8 }, " + "slice = 1,1:1,4:-:-", + a.SliceTensorShape(x, &y).ToString()); + EXPECT_EQ("", y.DebugString()); + } +} + +// Testing the computation of 
relative slices. +TEST(TensorSliceTest, ComputeRelative) { + // Easy case: base is "everything" + { + TensorSlice base = TensorSlice::ParseOrDie("-:-:-:-"); + TensorSlice sub = TensorSlice::ParseOrDie("-:1,2:-:3,4"); + TensorSlice relative; + base.ComputeRelative(sub, &relative); + EXPECT_EQ("-:1,2:-:3,4", relative.DebugString()); + } + + // A slightly more complicated case + { + TensorSlice base = TensorSlice::ParseOrDie("1,2:3,4:-:5,1"); + TensorSlice sub = TensorSlice::ParseOrDie("1,1:4,2:3,3:5,1"); + TensorSlice relative; + base.ComputeRelative(sub, &relative); + EXPECT_EQ("0,1:1,2:3,3:0,1", relative.DebugString()); + } +} + +TEST(TensorSliceTest, ExtentLength) { + TensorSliceProto proto; + // Define ptxt outside ASSERT_TRUE call to work around bug in some + // versions of gcc that breaks when you use raw string literals + // inside macro expansions. + // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 + const char* ptxt = R"PROTO( + extent { } + extent { start: 0 length: 10 } + extent { start: 14 length: 1 } + extent { } + )PROTO"; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); + EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(0))); + EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(1))); + EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(2))); + EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(3))); + EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(0))); + EXPECT_EQ(10, TensorSlice::GetExtentLength(proto.extent(1))); + EXPECT_EQ(1, TensorSlice::GetExtentLength(proto.extent(2))); + EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(3))); +} + +TEST(TensorSliceTest, Deserialization) { + // Serialization of + // extent { length: 5 } + // extent { start: 0 length: 10 } + // extent { start: 14 length: 1 } + // extent { start: 1 } + // extent { } + // in proto2 and proto3: + const char pb2[] = + "\x0A\x02\x10\x05\x0A\x04\x08\x00" + "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; + const char pb3[] = + "\x0A\x02\x10\x05\x0A\x02" + "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; + // (The difference is that in the proto3 version, "start: 0" isn't included + // since 0 is start's default value.) + + TensorSliceProto proto2; + ASSERT_TRUE(proto2.ParseFromArray(pb2, sizeof(pb2) - 1)); + TensorSlice ts2(proto2); + + TensorSliceProto proto3; + ASSERT_TRUE(proto3.ParseFromArray(pb3, sizeof(pb3) - 1)); + TensorSlice ts3(proto3); + + // Both serializations should be interpreted the same. 
+ EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString()); + EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc new file mode 100644 index 0000000000..4963c2c219 --- /dev/null +++ b/tensorflow/core/framework/tensor_test.cc @@ -0,0 +1,551 @@ +#include "tensorflow/core/public/tensor.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { + +TEST(TensorTest, Default) { + Tensor t; + EXPECT_EQ(t.dtype(), DT_FLOAT); + EXPECT_EQ(t.dims(), 1); + EXPECT_EQ(t.NumElements(), 0); +} + +TEST(TensorTest, DataType_Traits) { + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_FALSE(std::is_trivial::value); + + EXPECT_EQ(sizeof(bool), 1); + + // Unfortunately. std::complex::complex() initializes (0, 0). + EXPECT_FALSE(std::is_trivial::value); + EXPECT_FALSE(std::is_trivial>::value); + EXPECT_TRUE(std::is_trivial::value); + struct MyComplex { + float re, im; + }; + EXPECT_TRUE(std::is_trivial::value); +} + +template +void TestCopies(const Tensor& t) { + { + LOG(INFO) << "CopyFrom()"; + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.CopyFrom(t, t.shape())); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "operator=()"; + Tensor t2(t.dtype()); + t2 = t; + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "deep copy"; + Tensor t2(t.dtype(), t.shape()); + t2.flat() = t.flat(); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "AsProtoField()"; + TensorProto proto; + t.AsProtoField(&proto); + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.FromProto(proto)); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "AsProtoTensorContent()"; + TensorProto proto; + t.AsProtoTensorContent(&proto); + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.FromProto(proto)); + test::ExpectTensorEqual(t, t2); + // Make another copy via tensor_content field. 
+ *proto.mutable_tensor_content() = proto.tensor_content(); + Tensor t3(t.dtype()); + EXPECT_TRUE(t3.FromProto(proto)); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "AsTensor"; + gtl::ArraySlice values(t.flat().data(), t.NumElements()); + Tensor t2 = test::AsTensor(values, t.shape()); + test::ExpectTensorEqual(t, t2); + } +} + +TEST(Tensor_Float, Simple) { + Tensor t(DT_FLOAT, TensorShape({10, 20})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 20}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = static_cast(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_QInt8, Simple) { + Tensor t(DT_QINT8, TensorShape({2, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = qint8(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_QUInt8, Simple) { + Tensor t(DT_QUINT8, TensorShape({2, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = Eigen::QUInt8(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_QInt32, Simple) { + Tensor t(DT_QINT32, TensorShape({2, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = qint32(static_cast(a * b)); + } + } + TestCopies(t); +} + +TEST(Tensor_Float, Reshape) { + Tensor t(DT_FLOAT, TensorShape({2, 3, 4, 5})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 3, 4, 5}))); + + { + auto tensor = t.tensor(); + EXPECT_EQ(2, tensor.dimension(0)); + EXPECT_EQ(3, tensor.dimension(1)); + EXPECT_EQ(4, tensor.dimension(2)); + EXPECT_EQ(5, tensor.dimension(3)); + + // Set first and last elements. 
+ tensor(0, 0, 0, 0) = 0.01f; + tensor(1, 2, 3, 4) = 0.02f; + } + { + auto shaped = t.shaped({120}); + EXPECT_EQ(120, shaped.dimension(0)); + EXPECT_EQ(shaped(0), 0.01f); + EXPECT_EQ(shaped(119), 0.02f); + } + { + auto shaped = t.shaped({6, 20}); + EXPECT_EQ(6, shaped.dimension(0)); + EXPECT_EQ(20, shaped.dimension(1)); + EXPECT_EQ(shaped(0, 0), 0.01f); + EXPECT_EQ(shaped(5, 19), 0.02f); + } + { + auto shaped = t.shaped({6, 4, 5}); + EXPECT_EQ(6, shaped.dimension(0)); + EXPECT_EQ(4, shaped.dimension(1)); + EXPECT_EQ(5, shaped.dimension(2)); + EXPECT_EQ(shaped(0, 0, 0), 0.01f); + EXPECT_EQ(shaped(5, 3, 4), 0.02f); + } + { + auto shaped = t.shaped({2, 3, 4, 5}); + EXPECT_EQ(2, shaped.dimension(0)); + EXPECT_EQ(3, shaped.dimension(1)); + EXPECT_EQ(4, shaped.dimension(2)); + EXPECT_EQ(5, shaped.dimension(3)); + + EXPECT_EQ(shaped(0, 0, 0, 0), 0.01f); + EXPECT_EQ(shaped(1, 2, 3, 4), 0.02f); + } + { + auto flat = t.flat(); + EXPECT_EQ(flat(0), 0.01f); + EXPECT_EQ(120, flat.dimension(0)); + EXPECT_EQ(flat(0), 0.01f); + EXPECT_EQ(flat(119), 0.02f); + } + { + auto flat_inner_dims = t.flat_inner_dims(); + EXPECT_EQ(24, flat_inner_dims.dimension(0)); + EXPECT_EQ(5, flat_inner_dims.dimension(1)); + EXPECT_EQ(flat_inner_dims(0, 0), 0.01f); + EXPECT_EQ(flat_inner_dims(23, 4), 0.02f); + } +} + +TEST(Tensor_Scalar, Basics) { + { + Tensor t(DT_FLOAT, TensorShape({})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.scalar()() = 123.45f; + EXPECT_FLOAT_EQ(123.45f, Tt()); + } + { + Tensor t(DT_FLOAT, TensorShape({1})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.vec(); + EXPECT_EQ(1, Tt.size()); + t.vec()(0) = 123.45f; + EXPECT_FLOAT_EQ(123.45f, Tt(0)); + } + { + Tensor t(DT_FLOAT, TensorShape({1, 1, 1})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.flat()(0) = 123.45f; + EXPECT_FLOAT_EQ(123.45f, Tt()); + } + { + Tensor t(DT_STRING, TensorShape({})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.scalar()() = "foo"; + EXPECT_EQ("foo", Tt()); + } + { + Tensor t(DT_STRING, TensorShape({1})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.vec(); + EXPECT_EQ(1, Tt.size()); + t.flat()(0) = "foo"; + EXPECT_EQ("foo", Tt(0)); + } + { + Tensor t(DT_STRING, TensorShape({1, 1, 1})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.flat()(0) = "bar"; + EXPECT_EQ("bar", Tt()); + } + { + Tensor t(DT_FLOAT, TensorShape({0, 1})); + EXPECT_EQ(0, t.NumElements()); + auto Tt = t.flat(); + EXPECT_EQ(0, Tt.size()); + auto Tm = t.matrix(); + EXPECT_EQ(0, Tm.size()); + EXPECT_EQ(0, Tm.dimensions()[0]); + EXPECT_EQ(1, Tm.dimensions()[1]); + } +} + +TEST(Tensor_Float, Reshape_And_Slice_Assignment) { + // A test to experiment with a way to assign to a subset of a tensor + Tensor t(DT_FLOAT, TensorShape({10, 4, 3, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 4, 3, 2}))); + + // Get the N dimensional tensor (N==4 here) + auto e_t = t.tensor(); + // Reshape to view it as a two-dimensional tensor + auto e_2d = t.shaped({10, 4 * 3 * 2}); + for (int i = 0; i < 10; i++) { + // Assign a 1 x 4*3*2 matrix (really vector) to a slice of size + // 1 x 4*3*2 in e_t. 
+ Eigen::Tensor m(1, 4 * 3 * 2); + m.setConstant(i * 2.0); + + Eigen::DSizes indices(i, 0); + Eigen::DSizes sizes(1, 4 * 3 * 2); + e_2d.slice(indices, sizes) = m; + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 3; k++) { + for (int l = 0; l < 2; l++) { + EXPECT_EQ(e_t(i, j, k, l), i * 2.0f); + LOG(INFO) << i << "," << j << "," << k << "," << l + << " &e_t(i, j, k, l): " << &e_t(i, j, k, l) << " = " + << e_t(i, j, k, l); + } + } + } + } +} + +TEST(Tensor_String, Simple) { + Tensor t = test::AsTensor( + {"hello", "world", "machine", "learning", "new", "york"}, + TensorShape({3, 2})); + auto s = t.shape(); + ASSERT_EQ(s.dims(), 2); + ASSERT_EQ(s.dim_size(0), 3); + ASSERT_EQ(s.dim_size(1), 2); + auto m = t.matrix(); + EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(string) + 5 + 5 + 7 + 8 + 3 + 4); + + EXPECT_EQ(m(0, 0), "hello"); + EXPECT_EQ(m(0, 1), "world"); + EXPECT_EQ(m(1, 0), "machine"); + EXPECT_EQ(m(1, 1), "learning"); + EXPECT_EQ(m(2, 0), "new"); + EXPECT_EQ(m(2, 1), "york"); + + TestCopies(t); +} + +TEST(Tensor_Float, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * 2.0f; + Tensor t3 = test::AsTensor({0, 2, 4, 6, 8, 10}, t1.shape()); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_Int32, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * 2; + Tensor t3 = test::AsTensor({0, 2, 4, 6, 8, 10}, t1.shape()); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_QInt8, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() + qint8(-2); + Tensor t3 = test::AsTensor({-2, -1, 0, 1, 2, 3}, {2, 3}); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_QUInt8, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() + quint8(2); + Tensor t3 = test::AsTensor({2, 3, 4, 5, 6, 7}, {2, 3}); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_Int64, SimpleWithHelper) { + Tensor t1 = test::AsTensor( + {0LL << 48, 1LL << 48, 2LL << 48, 3LL << 48, 4LL << 48, 5LL << 48}, + {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * static_cast(2); + Tensor t3 = test::AsTensor( + {0LL << 48, 2LL << 48, 4LL << 48, 6LL << 48, 8LL << 48, 10LL << 48}, + {2, 3}); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_String, SimpleWithHelper) { + Tensor t1 = test::AsTensor({"0", "1", "2", "3", "4", "5"}, {2, 3}); + Tensor t2(DT_STRING, {2, 3}); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + t2.matrix()(i, j) = strings::StrCat(i * 3 + j); + } + } + + // Test with helper. + test::ExpectTensorEqual(t1, t2); +} + +TEST(Tensor_Bool, SimpleWithHelper) { + Tensor t1 = + test::AsTensor({false, true, false, true, false, true}, {2, 3}); + + Tensor t2(DT_BOOL, {2, 3}); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + t2.matrix()(i, j) = (((i + j) % 2) != 0); + } + } + + // Test with helper. 
+ test::ExpectTensorEqual(t1, t2); +} + +TEST(Tensor_Complex, Simple) { + Tensor t(DT_COMPLEX64, {4, 5, 3, 7}); + t.flat().setRandom(); + TestCopies(t); +} + +TEST(Tensor_Complex, SimpleWithHelper) { + { + Tensor t1 = test::AsTensor({0, + {1, 1}, + complex64(2), + complex64(3, 3), + complex64(0, 4), + complex64(2, 5)}, + {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * complex64(0, 2); + Tensor t3 = test::AsTensor( + {0, {-2, 2}, {0, 4}, {-6, 6}, {-8, 0}, {-10, 4}}, + // shape + {2, 3}); + test::ExpectTensorEqual(t2, t3); + } + + // Does some numeric operations for complex numbers. + { + const float PI = std::acos(-1); + const complex64 rotate_45 = std::polar(1.0f, PI / 4); + + // x contains all the 8-th root of unity. + Tensor x(DT_COMPLEX64, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + x.vec()(i) = std::pow(rotate_45, i); + } + + // Shift the roots by 45 degree. + Tensor y(DT_COMPLEX64, TensorShape({8})); + y.vec() = x.vec() * rotate_45; + Tensor y_expected(DT_COMPLEX64, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + y_expected.vec()(i) = std::pow(rotate_45, i + 1); + } + test::ExpectTensorNear(y, y_expected, 1e-5); + + // Raise roots to the power of 8. + Tensor z(DT_COMPLEX64, TensorShape({8})); + z.vec() = x.vec().pow(8); + Tensor z_expected(DT_COMPLEX64, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + z_expected.vec()(i) = 1; + } + test::ExpectTensorNear(z, z_expected, 1e-5); + } +} + +// On the alignment. +// +// As of 2015/8, tensorflow::Tensor allocates its buffer with 32-byte +// alignment. Tensor::tensor/flat/vec/matrix methods requires the the +// buffer satisfies Eigen::Aligned (e.g., 16-bytes aligned usually, +// and 32-bytes for AVX). Tensor::Slice requires the caller to ensure +// its result is aligned if the caller intends to use those methods. +// In this test case, we simply make sure each slice is 32-byte +// aligned: sizeof(float) * 4 * 2 = 32. +TEST(Tensor, Slice_Basic) { + Tensor saved; + { // General + Tensor x(DT_FLOAT, TensorShape({10, 4, 34})); + // Fills in known values. + for (int i = 0; i < 10; ++i) { + x.Slice(i, i + 1).flat().setConstant(i * 1.f); + } + // A simple slice along dim0. + Tensor y = x.Slice(4, 8); + EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 4, 34}))); + auto tx = x.tensor(); + auto ty = y.tensor(); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 34; ++k) { + EXPECT_EQ(ty(i, j, k), 4.0 + i); + EXPECT_EQ(&tx(4 + i, j, k), &ty(i, j, k)); + } + } + } + // A simple slice equivalent to identity. + TestCopies(y); + y = x.Slice(0, 10); + test::ExpectTensorEqual(x, y); + EXPECT_EQ(x.flat().data(), y.flat().data()); + + // A slice of a slice. + auto z = x.Slice(4, 8).Slice(2, 3); + auto tz = z.tensor(); + EXPECT_EQ(1, z.dim_size(0)); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 34; ++k) { + EXPECT_EQ(tz(0, j, k), 6.0); + } + } + + // x and y will be out of scope. But 'saved' should be alive. + saved = z; + } + { + EXPECT_EQ(1, saved.dim_size(0)); + auto tsaved = saved.tensor(); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 34; ++k) { + EXPECT_EQ(tsaved(0, j, k), 6.0); + } + } + } + { // Empty + Tensor x(DT_FLOAT, TensorShape({10, 0, 34})); + x.flat().setRandom(); + Tensor y = x.Slice(4, 8); + EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 0, 34}))); + } + + { + // Test unaligned access via a Slice. + Tensor x(DT_FLOAT, TensorShape({30})); + x.flat().setConstant(0.0); + + // Take an unaligned slice. 
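The alignment remark above is plain byte arithmetic on the slice offset; a standalone sketch of the check (the 32-byte figure is the allocation alignment the comment describes):

#include <cstddef>
#include <cstdint>

// Slice(start, limit) reuses the parent buffer at byte offset
// start * elems_per_dim0 * elem_size; the aligned accessors stay usable
// only if that offset is a multiple of the required alignment.
bool SliceStaysAligned(int64_t start, int64_t elems_per_dim0,
                       size_t elem_size, size_t alignment) {
  const int64_t offset_bytes =
      start * elems_per_dim0 * static_cast<int64_t>(elem_size);
  return offset_bytes % static_cast<int64_t>(alignment) == 0;
}
// For the {10, 4, 34} float tensor above, 4 * 34 * 4 = 544 bytes per dim0
// step and 544 % 32 == 0, so every Slice() of it stays aligned. The {30}
// vector below is sliced at start 1, i.e. byte offset 4, hence the
// unaligned_flat() accessor.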
+ Tensor y = x.Slice(1, 13); + y.unaligned_flat().setConstant(1.0); + for (int64 i = 0; i < y.NumElements(); ++i) { + EXPECT_EQ(1.0, y.unaligned_flat()(i)); + } + } +} + +static void BM_CreateAndDestroy(int iters) { + TensorShape shape({10, 20}); + while (--iters) { + Tensor t(DT_FLOAT, shape); + } +} +BENCHMARK(BM_CreateAndDestroy); + +static void BM_Assign(int iters) { + Tensor a(DT_FLOAT, TensorShape({10, 20})); + Tensor b(DT_FLOAT, TensorShape({10, 20})); + bool a_to_b = true; + while (--iters) { + if (a_to_b) { + b = a; + } else { + a = b; + } + a_to_b = !a_to_b; + } +} +BENCHMARK(BM_Assign); + +// Ensure tensor_data() works on empty tensors +TEST(Tensor, EmptyTensorData) { + Tensor empty; + EXPECT_EQ(empty.tensor_data().size(), 0); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc new file mode 100644 index 0000000000..b6cd12a864 --- /dev/null +++ b/tensorflow/core/framework/tensor_testutil.cc @@ -0,0 +1,43 @@ +#include +#include "tensorflow/core/framework/tensor_testutil.h" + +namespace tensorflow { +namespace test { + +template +bool IsClose(const T& x, const T& y, double atol, double rtol) { + return fabs(x - y) < atol + rtol * fabs(x); +} + +template +void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) { + auto Tx = x.flat(); + auto Ty = y.flat(); + for (int i = 0; i < Tx.size(); ++i) { + if (!IsClose(Tx(i), Ty(i), atol, rtol)) { + LOG(ERROR) << "x = " << x.DebugString(); + LOG(ERROR) << "y = " << y.DebugString(); + LOG(ERROR) << "atol = " << atol << " rtol = " << rtol + << " tol = " << atol + rtol * std::fabs(Tx(i)); + EXPECT_TRUE(false) << i << "-th element is not close " << Tx(i) << " vs. " + << Ty(i); + } + } +} + +void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) { + internal::AssertSameTypeDims(x, y); + switch (x.dtype()) { + case DT_FLOAT: + ExpectClose(x, y, atol, rtol); + break; + case DT_DOUBLE: + ExpectClose(x, y, atol, rtol); + break; + default: + LOG(FATAL) << "Unexpected type : " << DataTypeString(x.dtype()); + } +} + +} // end namespace test +} // end namespace tensorflow diff --git a/tensorflow/core/framework/tensor_testutil.h b/tensorflow/core/framework/tensor_testutil.h new file mode 100644 index 0000000000..53d6da0fb2 --- /dev/null +++ b/tensorflow/core/framework/tensor_testutil.h @@ -0,0 +1,189 @@ +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include + +namespace tensorflow { +namespace test { + +// Constructs a scalar tensor with 'val'. +template +Tensor AsScalar(const T& val) { + Tensor ret(DataTypeToEnum::value, {}); + ret.scalar()() = val; + return ret; +} + +// Constructs a flat tensor with 'vals'. +template +Tensor AsTensor(gtl::ArraySlice vals) { + Tensor ret(DataTypeToEnum::value, {static_cast(vals.size())}); + std::copy_n(vals.data(), vals.size(), ret.flat().data()); + return ret; +} + +// Constructs a tensor of "shape" with values "vals". +template +Tensor AsTensor(gtl::ArraySlice vals, const TensorShape& shape) { + Tensor ret; + CHECK(ret.CopyFrom(AsTensor(vals), shape)); + return ret; +} + +// Fills in '*tensor' with 'vals'. 
E.g., +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillValues(&x, {11, 21, 21, 22}); +template +void FillValues(Tensor* tensor, gtl::ArraySlice vals) { + auto flat = tensor->flat(); + CHECK_EQ(flat.size(), vals.size()); + if (flat.size() > 0) { + std::copy_n(vals.data(), vals.size(), flat.data()); + } +} + +// Fills in '*tensor' with a sequence of value of val, val+1, val+2, ... +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillIota(&x, 1.0); +template +void FillIota(Tensor* tensor, const T& val) { + auto flat = tensor->flat(); + std::iota(flat.data(), flat.data() + flat.size(), val); +} + +// Fills in '*tensor' with a sequence of value of fn(0), fn(1), ... +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillFn(&x, [](int i)->float { return i*i; }); +template +void FillFn(Tensor* tensor, std::function fn) { + auto flat = tensor->flat(); + for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i); +} + +// Expects "x" and "y" are tensors of the same type, same shape, and +// identical values. +template +void ExpectTensorEqual(const Tensor& x, const Tensor& y); + +// Expects "x" and "y" are tensors of the same type, same shape, and +// approxmiate equal values, each within "abs_err". +template +void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err); + +// Expects "x" and "y" are tensors of the same type (float or double), +// same shape and element-wise difference between x and y is no more +// than atol + rtol * abs(x). +void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6, + double rtol = 1e-6); + +// Implementation details. + +namespace internal { + +template +struct is_floating_point_type { + static const bool value = std::is_same::value || + std::is_same::value || + std::is_same >::value || + std::is_same >::value; +}; + +template +static void ExpectEqual(const T& a, const T& b) { + EXPECT_EQ(a, b); +} + +template <> +void ExpectEqual(const float& a, const float& b) { + EXPECT_FLOAT_EQ(a, b); +} + +template <> +void ExpectEqual(const double& a, const double& b) { + EXPECT_DOUBLE_EQ(a, b); +} + +template <> +void ExpectEqual(const complex64& a, const complex64& b) { + EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b; + EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b; +} + +inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), y.dtype()); + ASSERT_TRUE(x.IsSameSize(y)) + << "x.shape [" << x.shape().DebugString() << "] vs " + << "y.shape [ " << y.shape().DebugString() << "]"; +} + +template ::value> +struct Expector; + +template +struct Expector { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + auto a = x.flat(); + auto b = y.flat(); + for (int i = 0; i < a.size(); ++i) { + ExpectEqual(a(i), b(i)); + } + } +}; + +// Partial specialization for float and double. +template +struct Expector { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + auto a = x.flat(); + auto b = y.flat(); + for (int i = 0; i < a.size(); ++i) { + ExpectEqual(a(i), b(i)); + } + } + + static void Near(const T& a, const T& b, const double abs_err) { + if (a != b) { // Takes care of inf. 
+ EXPECT_LE(std::abs(a - b), abs_err) << "a = " << a << " b = " << b; + } + } + + static void Near(const Tensor& x, const Tensor& y, const double abs_err) { + ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + auto a = x.flat(); + auto b = y.flat(); + for (int i = 0; i < a.size(); ++i) { + Near(a(i), b(i), abs_err); + } + } +}; + +} // namespace internal + +template +void ExpectTensorEqual(const Tensor& x, const Tensor& y) { + internal::Expector::Equal(x, y); +} + +template +void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) { + static_assert(internal::is_floating_point_type::value, + "T is not a floating point types."); + internal::Expector::Near(x, y, abs_err); +} + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h new file mode 100644 index 0000000000..077d86d442 --- /dev/null +++ b/tensorflow/core/framework/tensor_types.h @@ -0,0 +1,92 @@ +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +// Helper to define Tensor types given that the scalar is of type T. +template +struct TTypes { + // Rank- tensor of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> Tensor; + typedef Eigen::TensorMap, + Eigen::Aligned> ConstTensor; + + // Unaligned Rank- tensor of scalar type T. + typedef Eigen::TensorMap > + UnalignedTensor; + typedef Eigen::TensorMap > + UnalignedConstTensor; + + typedef Eigen::TensorMap, + Eigen::Aligned> Tensor32Bit; + + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor>, + Eigen::Aligned> Scalar; + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor>, + Eigen::Aligned> ConstScalar; + + // Unaligned Scalar tensor of scalar type T. + typedef Eigen::TensorMap, Eigen::RowMajor> > UnalignedScalar; + typedef Eigen::TensorMap, Eigen::RowMajor> > UnalignedConstScalar; + + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap, Eigen::Aligned> + Flat; + typedef Eigen::TensorMap, + Eigen::Aligned> ConstFlat; + typedef Eigen::TensorMap, Eigen::Aligned> + Vec; + typedef Eigen::TensorMap, + Eigen::Aligned> ConstVec; + + // Unaligned Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap > UnalignedFlat; + typedef Eigen::TensorMap > + UnalignedConstFlat; + typedef Eigen::TensorMap > UnalignedVec; + typedef Eigen::TensorMap > + UnalignedConstVec; + + // Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap, Eigen::Aligned> + Matrix; + typedef Eigen::TensorMap, + Eigen::Aligned> ConstMatrix; + + // Unaligned Rank-2 tensor (matrix) of scalar type T. 
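The aliases in TTypes are all Eigen::TensorMap instantiations over memory the map does not own; a minimal standalone sketch using the same third_party Eigen header as above:

#include <cstdio>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

void TensorMapSketch() {
  float storage[6] = {0, 1, 2, 3, 4, 5};
  // Close in spirit to TTypes<float>::UnalignedMatrix: a rank-2, row-major,
  // non-owning view of 'storage' with shape 2 x 3.
  Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> m(storage, 2, 3);
  m(1, 2) = 42.0f;                  // writes through to storage[1 * 3 + 2]
  std::printf("%g\n", storage[5]);  // prints 42
}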
+ typedef Eigen::TensorMap > + UnalignedMatrix; + typedef Eigen::TensorMap > + UnalignedConstMatrix; +}; + +typedef typename TTypes::Tensor32Bit::Index Index32; + +template +Eigen::DSizes To32BitDims(const DSizes& in) { + Eigen::DSizes out; + for (int i = 0; i < DSizes::count; ++i) { + out[i] = in[i]; + } + return out; +} + +template +typename TTypes::Tensor32Bit +To32Bit(TensorType in) { + typedef typename TTypes::Tensor32Bit RetType; + return RetType(in.data(), To32BitDims(in.dimensions())); +} + +} // namespace tensorflow +#endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc new file mode 100644 index 0000000000..7353191c74 --- /dev/null +++ b/tensorflow/core/framework/tensor_util.cc @@ -0,0 +1,28 @@ +#include "tensorflow/core/framework/tensor_util.h" + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace tensor { + +Tensor DeepCopy(const Tensor& other) { + Tensor tmp = Tensor(other.dtype(), other.shape()); + if (DataTypeCanUseMemcpy(other.dtype())) { + StringPiece other_data = other.tensor_data(); + + // We use StringPiece as a convenient map over the tensor buffer, + // but we cast the type to get to the underlying buffer to do the + // copy. + StringPiece tmp_data = tmp.tensor_data(); + memcpy(const_cast(tmp_data.data()), other_data.data(), + other_data.size()); + } else { + CHECK_EQ(DT_STRING, other.dtype()); + tmp.flat() = other.flat(); + } + return tmp; +} + +} // namespace tensor +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h new file mode 100644 index 0000000000..a8dde1d0ca --- /dev/null +++ b/tensorflow/core/framework/tensor_util.h @@ -0,0 +1,21 @@ +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ + +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace tensor { + +// DeepCopy returns a tensor whose contents are a deep copy of the +// contents of 'other'. This function is intended only for +// convenience, not speed. +// +// REQUIRES: 'other' must point to data stored in CPU memory. +// REQUIRES: 'other' must be a Tensor of a copy-able type if +// 'other' is not appropriately memory-aligned. +Tensor DeepCopy(const Tensor& other); + +} // namespace tensor +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc new file mode 100644 index 0000000000..fef7468151 --- /dev/null +++ b/tensorflow/core/framework/tensor_util_test.cc @@ -0,0 +1,124 @@ +#include "tensorflow/core/framework/tensor_util.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" +#include + +namespace tensorflow { +namespace { + +TEST(TensorUtil, DeepCopy0d) { + Tensor x(DT_FLOAT, TensorShape({})); + x.scalar()() = 10.0; + + // Make y a deep copy of x and then change it. + Tensor y = tensor::DeepCopy(x); + y.scalar()() = 20.0; + + // x doesn't change + EXPECT_EQ(10.0, x.scalar()()); + + // Change x. + x.scalar()() = 30.0; + + // Y doesn't change. + EXPECT_EQ(20.0, y.scalar()()); + + Tensor z = tensor::DeepCopy(y); + + // Change y. + y.scalar()() = 40.0; + + // The final states should all be different. + EXPECT_EQ(20.0, z.scalar()()); + EXPECT_EQ(30.0, x.scalar()()); + EXPECT_EQ(40.0, y.scalar()()); + + // Should have the same shape and type. 
+ EXPECT_EQ(TensorShape({}), x.shape()); + EXPECT_EQ(TensorShape({}), y.shape()); + EXPECT_EQ(TensorShape({}), z.shape()); + + EXPECT_EQ(DT_FLOAT, x.dtype()); + EXPECT_EQ(DT_FLOAT, y.dtype()); + EXPECT_EQ(DT_FLOAT, z.dtype()); +} + +TEST(TensorUtil, DeepCopy) { + Tensor x(DT_FLOAT, TensorShape({1})); + x.flat()(0) = 10.0; + + // Make y a deep copy of x and then change it. + Tensor y = tensor::DeepCopy(x); + y.flat()(0) = 20.0; + + // x doesn't change + EXPECT_EQ(10.0, x.flat()(0)); + + // Change x. + x.flat()(0) = 30.0; + + // Y doesn't change. + EXPECT_EQ(20.0, y.flat()(0)); + + Tensor z = tensor::DeepCopy(y); + + // Change y. + y.flat()(0) = 40.0; + + // The final states should all be different. + EXPECT_EQ(20.0, z.flat()(0)); + EXPECT_EQ(30.0, x.flat()(0)); + EXPECT_EQ(40.0, y.flat()(0)); + + // Should have the same shape and type. + EXPECT_EQ(TensorShape({1}), x.shape()); + EXPECT_EQ(TensorShape({1}), y.shape()); + EXPECT_EQ(TensorShape({1}), z.shape()); + + EXPECT_EQ(DT_FLOAT, x.dtype()); + EXPECT_EQ(DT_FLOAT, y.dtype()); + EXPECT_EQ(DT_FLOAT, z.dtype()); + + // Test string deep copy + Tensor str1(DT_STRING, TensorShape({2})); + str1.flat()(0) = "foo1"; + str1.flat()(1) = "foo2"; + Tensor str2 = tensor::DeepCopy(str1); + str2.flat()(0) = "bar1"; + str2.flat()(1) = "bar2"; + EXPECT_NE(str2.flat()(0), str1.flat()(0)); +} + +TEST(TensorUtil, DeepCopySlice) { + Tensor x(DT_INT32, TensorShape({10})); + x.flat().setConstant(1); + + // Slice 'x' -- y still refers to the same buffer. + Tensor y = x.Slice(2, 6); + + // Do a deep copy of y, which is a slice. + Tensor z = tensor::DeepCopy(y); + + // Set x to be different. + x.flat().setConstant(2); + + EXPECT_EQ(TensorShape({10}), x.shape()); + EXPECT_EQ(TensorShape({4}), y.shape()); + EXPECT_EQ(TensorShape({4}), z.shape()); + EXPECT_EQ(DT_INT32, x.dtype()); + EXPECT_EQ(DT_INT32, y.dtype()); + EXPECT_EQ(DT_INT32, z.dtype()); + + // x and y should now all be '2', but z should be '1'. 
+ for (int i = 0; i < 10; ++i) { + EXPECT_EQ(2, x.flat()(i)); + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(2, y.unaligned_flat()(i)); + EXPECT_EQ(1, z.flat()(i)); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc new file mode 100644 index 0000000000..78311ded19 --- /dev/null +++ b/tensorflow/core/framework/tracking_allocator.cc @@ -0,0 +1,100 @@ +#include "tensorflow/core/framework/tracking_allocator.h" + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +TrackingAllocator::TrackingAllocator(Allocator* allocator) + : allocator_(allocator), + ref_(1), + allocated_(0), + high_watermark_(0), + total_bytes_(0) {} + +void* TrackingAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { + void* ptr = allocator_->AllocateRaw(alignment, num_bytes); + // If memory is exhausted AllocateRaw returns nullptr, and we should + // pass this through to the caller + if (nullptr == ptr) { + return ptr; + } + if (allocator_->TracksAllocationSizes()) { + size_t allocated_bytes = allocator_->AllocatedSize(ptr); + { + mutex_lock lock(mu_); + allocated_ += allocated_bytes; + high_watermark_ = std::max(high_watermark_, allocated_); + total_bytes_ += allocated_bytes; + ++ref_; + } + } else { + mutex_lock lock(mu_); + total_bytes_ += num_bytes; + ++ref_; + } + return ptr; +} + +void TrackingAllocator::DeallocateRaw(void* ptr) { + // freeing a null ptr is a no-op + if (nullptr == ptr) { + return; + } + bool should_delete; + // fetch the following outside the lock in case the call to + // AllocatedSize is slow + bool tracks_allocation_sizes = allocator_->TracksAllocationSizes(); + size_t allocated_bytes = 0; + if (tracks_allocation_sizes) { + allocated_bytes = allocator_->AllocatedSize(ptr); + } + Allocator* allocator = allocator_; + { + mutex_lock lock(mu_); + if (tracks_allocation_sizes) { + CHECK_GE(allocated_, allocated_bytes); + allocated_ -= allocated_bytes; + } + should_delete = UnRef(); + } + allocator->DeallocateRaw(ptr); + if (should_delete) { + delete this; + } +} + +bool TrackingAllocator::TracksAllocationSizes() { + return allocator_->TracksAllocationSizes(); +} + +size_t TrackingAllocator::RequestedSize(void* ptr) { + return allocator_->RequestedSize(ptr); +} + +size_t TrackingAllocator::AllocatedSize(void* ptr) { + return allocator_->AllocatedSize(ptr); +} + +std::pair TrackingAllocator::GetSizesAndUnRef() { + size_t high_watermark; + size_t total_bytes; + bool should_delete; + { + mutex_lock lock(mu_); + high_watermark = high_watermark_; + total_bytes = total_bytes_; + should_delete = UnRef(); + } + if (should_delete) { + delete this; + } + return std::make_pair(total_bytes, high_watermark); +} + +bool TrackingAllocator::UnRef() { + CHECK_GE(ref_, 1); + --ref_; + return (ref_ == 0); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h new file mode 100644 index 0000000000..f809e3822c --- /dev/null +++ b/tensorflow/core/framework/tracking_allocator.h @@ -0,0 +1,80 @@ +#ifndef TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_ +#define TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_ + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// TrackingAllocator is a wrapper for an Allocator. 
It keeps a running +// count of the number of bytes allocated through the wrapper. It is +// used by the Executor to "charge" allocations to particular Op +// executions. Each Op gets a separate TrackingAllocator wrapper +// around the underlying allocator. +// +// The implementation assumes the invariant that all calls to +// AllocateRaw by an Op (or work items spawned by the Op) will occur +// before the Op's Compute method returns. Thus the high watermark is +// established once Compute returns. +// +// DeallocateRaw can be called long after the Op has finished, +// e.g. when an output tensor is deallocated, and the wrapper cannot +// be deleted until the last of these calls has occurred. The +// TrackingAllocator keeps track of outstanding calls using a +// reference count, and deletes itself once the last call has been +// received and the high watermark has been retrieved. +class TrackingAllocator : public Allocator { + public: + explicit TrackingAllocator(Allocator* allocator); + string Name() override { return allocator_->Name(); } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + bool TracksAllocationSizes() override; + size_t RequestedSize(void* ptr) override; + size_t AllocatedSize(void* ptr) override; + + // If the underlying allocator tracks allocation sizes, this returns + // a pair where the first value is the total number of bytes + // allocated through this wrapper, and the second value is the high + // watermark of bytes allocated through this wrapper. If the + // underlying allocator does not track allocation sizes the first + // value is the total number of bytes requested through this wrapper + // and the second is 0. + // + // After GetSizesAndUnref is called, the only further calls allowed + // on this wrapper are calls to DeallocateRaw with pointers that + // were allocated by this wrapper and have not yet been + // deallocated. After this call completes and all allocated pointers + // have been deallocated the wrapper will delete itself. + std::pair GetSizesAndUnRef(); + + private: + ~TrackingAllocator() override {} + bool UnRef() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + Allocator* allocator_; // not owned. + mutex mu_; + // the number of calls to AllocateRaw that have not yet been matched + // by a corresponding call to DeAllocateRaw, plus 1 if the Executor + // has not yet read out the high watermark. + int ref_ GUARDED_BY(mu_); + // the current number of outstanding bytes that have been allocated + // by this wrapper, or 0 if the underlying allocator does not track + // allocation sizes. + size_t allocated_ GUARDED_BY(mu_); + // the maximum number of outstanding bytes that have been allocated + // by this wrapper, or 0 if the underlying allocator does not track + // allocation sizes. + size_t high_watermark_ GUARDED_BY(mu_); + // the total number of bytes that have been allocated by this + // wrapper if the underlying allocator tracks allocation sizes, + // otherwise the total number of bytes that have been requested by + // this allocator. 
+ size_t total_bytes_ GUARDED_BY(mu_); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_ diff --git a/tensorflow/core/framework/tracking_allocator_test.cc b/tensorflow/core/framework/tracking_allocator_test.cc new file mode 100644 index 0000000000..90ce851775 --- /dev/null +++ b/tensorflow/core/framework/tracking_allocator_test.cc @@ -0,0 +1,115 @@ +#include "tensorflow/core/framework/tracking_allocator.h" + +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/logging.h" +#include + +namespace tensorflow { + +class TestableSizeTrackingAllocator : public Allocator { + public: + string Name() override { return "test"; } + void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override { + void* ptr = malloc(num_bytes); + size_map_[ptr] = num_bytes; + return ptr; + } + void DeallocateRaw(void* ptr) override { + const auto& iter = size_map_.find(ptr); + EXPECT_NE(size_map_.end(), iter); + size_map_.erase(iter); + free(ptr); + } + bool TracksAllocationSizes() override { return true; } + size_t RequestedSize(void* ptr) override { + const auto& iter = size_map_.find(ptr); + EXPECT_NE(size_map_.end(), iter); + return iter->second; + } + + private: + std::unordered_map size_map_; +}; + +class NoMemoryAllocator : public Allocator { + public: + string Name() override { return "test"; } + void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override { + return nullptr; + } + void DeallocateRaw(void* ptr) override {} + bool TracksAllocationSizes() override { return true; } +}; + +TEST(TrackingAllocatorTest, SimpleNoTracking) { + Allocator* a = cpu_allocator(); + + EXPECT_FALSE(a->TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(a); + + void* p1 = ta->AllocateRaw(4, 4); + ta->Deallocate(p1); + void* p2 = ta->AllocateRaw(4, 12); + + std::pair sizes = ta->GetSizesAndUnRef(); + + EXPECT_EQ(16, sizes.first); + EXPECT_EQ(0, sizes.second); + + ta->Deallocate(p2); +} + +TEST(TrackingAllocatorTest, SimpleTracking) { + TestableSizeTrackingAllocator a = TestableSizeTrackingAllocator(); + + EXPECT_TRUE(a.TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(&a); + + void* p1 = ta->AllocateRaw(4, 12); + ta->Deallocate(p1); + void* p2 = ta->AllocateRaw(4, 4); + + std::pair sizes = ta->GetSizesAndUnRef(); + + EXPECT_EQ(16, sizes.first); + EXPECT_EQ(12, sizes.second); + + ta->Deallocate(p2); +} + +TEST(TrackingAllocatorTest, OutOfMemory) { + NoMemoryAllocator a; + + EXPECT_TRUE(a.TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(&a); + + void* p1 = ta->AllocateRaw(4, 12); + EXPECT_EQ(nullptr, p1); + + std::pair sizes = ta->GetSizesAndUnRef(); + + EXPECT_EQ(0, sizes.first); + EXPECT_EQ(0, sizes.second); +} + +TEST(TrackingAllocatorTest, FreeNullPtr) { + NoMemoryAllocator a; + + EXPECT_TRUE(a.TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(&a); + + ta->DeallocateRaw(nullptr); + + std::pair sizes = ta->GetSizesAndUnRef(); + + EXPECT_EQ(0, sizes.first); + EXPECT_EQ(0, sizes.second); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/type_traits.h b/tensorflow/core/framework/type_traits.h new file mode 100644 index 0000000000..d87b6ff49b --- /dev/null +++ b/tensorflow/core/framework/type_traits.h @@ -0,0 +1,69 @@ +#ifndef TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ +#define TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ + +#include +#include + +#include "tensorflow/core/framework/types.h" +#include 
"tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Functions to define quantization attribute of types. +struct true_type { + static const bool value = true; +}; +struct false_type { + static const bool value = false; +}; + +// Default is_quantized is false. +template +struct is_quantized : false_type {}; + +// Specialize the quantized types. +template <> +struct is_quantized : true_type {}; +template <> +struct is_quantized : true_type {}; +template <> +struct is_quantized : true_type {}; + +// All types not specialized are marked invalid. +template +struct IsValidDataType { + static constexpr bool value = false; +}; + +// Extra validity checking; not part of public API. +struct TestIsValidDataType { + static_assert(IsValidDataType::value, "Incorrect impl for int64"); + static_assert(IsValidDataType::value, "Incorrect impl for int32"); +}; + +} // namespace tensorflow + +// Define numeric limits for our quantized as subclasses of the +// standard types. +namespace std { +template <> +class numeric_limits + : public numeric_limits {}; +template <> +class numeric_limits + : public numeric_limits {}; +template <> +class numeric_limits + : public numeric_limits {}; + +// Specialize is_signed for quantized types. +template <> +struct is_signed : public is_signed {}; +template <> +struct is_signed : public is_signed {}; +template <> +struct is_signed : public is_signed {}; + +} // namespace std + +#endif // TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc new file mode 100644 index 0000000000..01b9fca3b6 --- /dev/null +++ b/tensorflow/core/framework/types.cc @@ -0,0 +1,210 @@ +#include "tensorflow/core/framework/types.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +bool DeviceType::operator<(const DeviceType& other) const { + return type_ < other.type_; +} + +bool DeviceType::operator==(const DeviceType& other) const { + return type_ == other.type_; +} + +std::ostream& operator<<(std::ostream& os, const DeviceType& d) { + os << d.type(); + return os; +} + +const char* const DEVICE_CPU = "CPU"; +const char* const DEVICE_GPU = "GPU"; + +string DataTypeString(DataType dtype) { + if (IsRefType(dtype)) { + DataType non_ref = static_cast(dtype - kDataTypeRefOffset); + return strings::StrCat(DataTypeString(non_ref), "_ref"); + } + switch (dtype) { + case DT_INVALID: + return "INVALID"; + case DT_FLOAT: + return "float"; + case DT_DOUBLE: + return "double"; + case DT_INT32: + return "int32"; + case DT_UINT8: + return "uint8"; + case DT_INT16: + return "int16"; + case DT_INT8: + return "int8"; + case DT_STRING: + return "string"; + case DT_COMPLEX64: + return "complex64"; + case DT_INT64: + return "int64"; + case DT_BOOL: + return "bool"; + case DT_QINT8: + return "qint8"; + case DT_QUINT8: + return "quint8"; + case DT_QINT32: + return "qint32"; + case DT_BFLOAT16: + return "bfloat16"; + default: + LOG(FATAL) << "Unrecognized DataType enum value " << dtype; + return ""; + } +} + +bool DataTypeFromString(StringPiece sp, DataType* dt) { + if (sp.ends_with("_ref")) { + sp.remove_suffix(4); + DataType non_ref; + if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) { + *dt = static_cast(non_ref + kDataTypeRefOffset); + return true; + } else { + return false; + } + } + + if (sp == "float" || sp == "float32") { + *dt = DT_FLOAT; + return true; + } else if (sp == "double" || sp == 
"float64") { + *dt = DT_DOUBLE; + return true; + } else if (sp == "int32") { + *dt = DT_INT32; + return true; + } else if (sp == "uint8") { + *dt = DT_UINT8; + return true; + } else if (sp == "int16") { + *dt = DT_INT16; + return true; + } else if (sp == "int8") { + *dt = DT_INT8; + return true; + } else if (sp == "string") { + *dt = DT_STRING; + return true; + } else if (sp == "complex64") { + *dt = DT_COMPLEX64; + return true; + } else if (sp == "int64") { + *dt = DT_INT64; + return true; + } else if (sp == "bool") { + *dt = DT_BOOL; + return true; + } else if (sp == "qint8") { + *dt = DT_QINT8; + return true; + } else if (sp == "quint8") { + *dt = DT_QUINT8; + return true; + } else if (sp == "qint32") { + *dt = DT_QINT32; + return true; + } else if (sp == "bfloat16") { + *dt = DT_BFLOAT16; + return true; + } + return false; +} + +string DeviceTypeString(DeviceType device_type) { return device_type.type(); } + +string DataTypeSliceString(const DataTypeSlice types) { + string out; + for (auto it = types.begin(); it != types.end(); ++it) { + strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "), + DataTypeString(*it)); + } + return out; +} + +DataTypeVector AllTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, + DT_INT8, DT_STRING, DT_COMPLEX64, DT_INT64, DT_BOOL, + DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +#ifndef __ANDROID__ + +DataTypeVector RealNumberTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8}; +} + +DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; } + +DataTypeVector RealAndQuantizedTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, + DT_INT16, DT_INT8, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +DataTypeVector NumberTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, + DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +#else // __ANDROID__ + +DataTypeVector RealNumberTypes() { return {DT_FLOAT, DT_INT32}; } + +DataTypeVector NumberTypes() { + return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; } + +DataTypeVector RealAndQuantizedTypes() { + return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +#endif // __ANDROID__ + +// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying +// is_simple in tensor.cc (and possible choose a more general name?) 
+bool DataTypeCanUseMemcpy(DataType dt) { + switch (dt) { + case DT_FLOAT: + case DT_DOUBLE: + case DT_INT32: + case DT_UINT8: + case DT_INT16: + case DT_INT8: + case DT_COMPLEX64: + case DT_INT64: + case DT_BOOL: + case DT_QINT8: + case DT_QUINT8: + case DT_QINT32: + case DT_BFLOAT16: + return true; + default: + return false; + } +} + +bool DataTypeIsQuantized(DataType dt) { + switch (dt) { + case DT_QINT8: + case DT_QUINT8: + case DT_QINT32: + return true; + default: + return false; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h new file mode 100644 index 0000000000..2d417cf076 --- /dev/null +++ b/tensorflow/core/framework/types.h @@ -0,0 +1,168 @@ +#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_TYPES_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" + +namespace tensorflow { + +// MemoryType is used to describe whether input or output Tensors of +// an OpKernel should reside in "Host memory" (e.g., CPU memory) or +// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU +// devices). +enum MemoryType { + DEVICE_MEMORY = 0, + HOST_MEMORY = 1, +}; + +// A DeviceType is just a string, but we wrap it up in a class to give +// some type checking as we're passing these around +class DeviceType { + public: + DeviceType(const char* type) // NOLINT(runtime/explicit) + : type_(type) {} + + explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} + + const char* type() const { return type_.c_str(); } + + bool operator<(const DeviceType& other) const; + bool operator==(const DeviceType& other) const; + bool operator!=(const DeviceType& other) const { return !(*this == other); } + + private: + string type_; +}; +std::ostream& operator<<(std::ostream& os, const DeviceType& d); + +// Convenient constants that can be passed to a DeviceType constructor +extern const char* const DEVICE_CPU; // "CPU" +extern const char* const DEVICE_GPU; // "GPU" + +typedef gtl::InlinedVector MemoryTypeVector; + +typedef gtl::InlinedVector DataTypeVector; +typedef gtl::ArraySlice DataTypeSlice; + +typedef gtl::InlinedVector DeviceTypeVector; + +// Convert the enums to strings for errors: +string DataTypeString(DataType dtype); +string DeviceTypeString(DeviceType device_type); +string DataTypeSliceString(const DataTypeSlice dtypes); +inline string DataTypeVectorString(const DataTypeVector& dtypes) { + return DataTypeSliceString(dtypes); +} + +// If "sp" names a valid type, store it in "*dt" and return true. Otherwise, +// return false. +bool DataTypeFromString(StringPiece sp, DataType* dt); + +// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc. 
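+// A minimal sketch of the helpers defined below (enum values are the ones in
+// types.proto, e.g. DT_FLOAT == 1 and DT_FLOAT_REF == 101):
+//
+//   DataType r = MakeRefType(DT_FLOAT);  // DT_FLOAT_REF
+//   IsRefType(r);                        // true
+//   RemoveRefType(r);                    // DT_FLOAT
+//   BaseType(DT_FLOAT);                  // DT_FLOAT (unchanged for non-refs)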
+enum { kDataTypeRefOffset = 100 }; +inline bool IsRefType(DataType dtype) { + return dtype > static_cast(kDataTypeRefOffset); +} +inline DataType MakeRefType(DataType dtype) { + DCHECK(!IsRefType(dtype)); + return static_cast(dtype + kDataTypeRefOffset); +} +inline DataType RemoveRefType(DataType dtype) { + DCHECK(IsRefType(dtype)); + return static_cast(dtype - kDataTypeRefOffset); +} +inline DataType BaseType(DataType dtype) { + return IsRefType(dtype) ? RemoveRefType(dtype) : dtype; +} + +// Returns true if the actual type is the same as or ref of the expected type. +inline bool TypesCompatible(DataType expected, DataType actual) { + return expected == actual || expected == BaseType(actual); +} + +// Does not include _ref types. +DataTypeVector AllTypes(); + +// Return the list of all numeric types. +// NOTE: On Android, we only include the float and int32 types for now. +DataTypeVector RealNumberTypes(); // Types that support '<' and '>'. +DataTypeVector NumberTypes(); // Includes complex and quantized types. + +DataTypeVector QuantizedTypes(); +DataTypeVector RealAndQuantizedTypes(); // Types that support '<' and + // '>', including quantized + // types + +// Validates type T for whether it is a supported DataType. +template +struct IsValidDataType; + +// DataTypeToEnum::v() and DataTypeToEnum::value are the DataType +// constants for T, e.g. DataTypeToEnum::v() is DT_FLOAT. +template +struct DataTypeToEnum { + static_assert(IsValidDataType::value, "Specified Data Type not supported"); +}; // Specializations below + +// EnumToDataType::Type is the type for DataType constant VALUE, e.g. +// EnumToDataType::Type is float. +template +struct EnumToDataType {}; // Specializations below + +// Template specialization for both DataTypeToEnum and EnumToDataType. +#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \ + template <> \ + struct DataTypeToEnum { \ + static DataType v() { return ENUM; } \ + static DataType ref() { return MakeRefType(ENUM); } \ + static constexpr DataType value = ENUM; \ + }; \ + template <> \ + struct IsValidDataType { \ + static constexpr bool value = true; \ + }; \ + template <> \ + struct EnumToDataType { \ + typedef TYPE Type; \ + } + +// We use Eigen's QInt implementations for our quantized int types. +typedef Eigen::QInt8 qint8; +typedef Eigen::QUInt8 quint8; +typedef Eigen::QInt32 qint32; + +MATCH_TYPE_AND_ENUM(float, DT_FLOAT); +MATCH_TYPE_AND_ENUM(double, DT_DOUBLE); +MATCH_TYPE_AND_ENUM(int32, DT_INT32); +MATCH_TYPE_AND_ENUM(uint8, DT_UINT8); +MATCH_TYPE_AND_ENUM(int16, DT_INT16); +MATCH_TYPE_AND_ENUM(int8, DT_INT8); +MATCH_TYPE_AND_ENUM(string, DT_STRING); +MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64); +MATCH_TYPE_AND_ENUM(int64, DT_INT64); +MATCH_TYPE_AND_ENUM(bool, DT_BOOL); +MATCH_TYPE_AND_ENUM(qint8, DT_QINT8); +MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8); +MATCH_TYPE_AND_ENUM(qint32, DT_QINT32); +MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16); + +#undef MATCH_TYPE_AND_ENUM + +bool DataTypeCanUseMemcpy(DataType dt); + +bool DataTypeIsQuantized(DataType dt); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TYPES_H_ diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto new file mode 100644 index 0000000000..e5dc9c45a0 --- /dev/null +++ b/tensorflow/core/framework/types.proto @@ -0,0 +1,48 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. 
+ DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + + // TODO(josh11b): DT_GENERIC_PROTO = ??; + // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? DT_UINT16? + // TODO(zhifengc): DT_COMPLEX128 (double-precision complex)? + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; +} diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc new file mode 100644 index 0000000000..eb92600397 --- /dev/null +++ b/tensorflow/core/framework/types_test.cc @@ -0,0 +1,117 @@ +#include "tensorflow/core/framework/types.h" + +#include +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace { + +TEST(TypesTest, DeviceTypeName) { + EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU))); + EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU))); +} + +TEST(TypesTest, kDataTypeRefOffset) { + // Basic sanity check + EXPECT_EQ(DT_FLOAT + kDataTypeRefOffset, DT_FLOAT_REF); + + // Use the meta-data provided by proto2 to iterate through the basic + // types and validate that adding kDataTypeRefOffset gives the + // corresponding reference type. + const auto* enum_descriptor = DataType_descriptor(); + int e = DataType_MIN; + if (e == DT_INVALID) ++e; + int e_ref = e + kDataTypeRefOffset; + EXPECT_FALSE(DataType_IsValid(e_ref - 1)) + << "Reference enum " + << enum_descriptor->FindValueByNumber(e_ref - 1)->name() + << " without corresponding base enum with value " << e - 1; + for (; + DataType_IsValid(e) && DataType_IsValid(e_ref) && e_ref <= DataType_MAX; + ++e, ++e_ref) { + string enum_name = enum_descriptor->FindValueByNumber(e)->name(); + string enum_ref_name = enum_descriptor->FindValueByNumber(e_ref)->name(); + EXPECT_EQ(enum_name + "_REF", enum_ref_name) + << enum_name << "_REF should have value " << e_ref << " not " + << enum_ref_name; + // Validate DataTypeString() as well. 
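+    // For example, when e == DT_FLOAT this verifies that
+    // DataTypeString(DT_FLOAT) == "float" and
+    // DataTypeString(DT_FLOAT_REF) == "float_ref".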
+ DataType dt_e = static_cast(e); + DataType dt_e_ref = static_cast(e_ref); + EXPECT_EQ(DataTypeString(dt_e) + "_ref", DataTypeString(dt_e_ref)); + + // Test DataTypeFromString reverse conversion + DataType dt_e2, dt_e2_ref; + EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e), &dt_e2)); + EXPECT_EQ(dt_e, dt_e2); + EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e_ref), &dt_e2_ref)); + EXPECT_EQ(dt_e_ref, dt_e2_ref); + } + ASSERT_FALSE(DataType_IsValid(e)) + << "Should define " << enum_descriptor->FindValueByNumber(e)->name() + << "_REF to be " << e_ref; + ASSERT_FALSE(DataType_IsValid(e_ref)) + << "Extra reference enum " + << enum_descriptor->FindValueByNumber(e_ref)->name() + << " without corresponding base enum with value " << e; + ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for " + << e_ref; + + // Make sure there are no enums defined after the last regular type before + // the first reference type. + for (; e < DataType_MIN + kDataTypeRefOffset; ++e) { + EXPECT_FALSE(DataType_IsValid(e)) + << "Discontinuous enum value " + << enum_descriptor->FindValueByNumber(e)->name() << " = " << e; + } +} + +TEST(TypesTest, DataTypeFromString) { + DataType dt; + ASSERT_TRUE(DataTypeFromString("int32", &dt)); + EXPECT_EQ(DT_INT32, dt); + ASSERT_TRUE(DataTypeFromString("int32_ref", &dt)); + EXPECT_EQ(DT_INT32_REF, dt); + EXPECT_FALSE(DataTypeFromString("int32_ref_ref", &dt)); + EXPECT_FALSE(DataTypeFromString("foo", &dt)); + EXPECT_FALSE(DataTypeFromString("foo_ref", &dt)); + ASSERT_TRUE(DataTypeFromString("int64", &dt)); + EXPECT_EQ(DT_INT64, dt); + ASSERT_TRUE(DataTypeFromString("int64_ref", &dt)); + EXPECT_EQ(DT_INT64_REF, dt); + ASSERT_TRUE(DataTypeFromString("quint8_ref", &dt)); + EXPECT_EQ(DT_QUINT8_REF, dt); + ASSERT_TRUE(DataTypeFromString("bfloat16", &dt)); + EXPECT_EQ(DT_BFLOAT16, dt); +} + +template +static bool GetQuantized() { + return is_quantized::value; +} + +TEST(TypesTest, QuantizedTypes) { + // NOTE: GUnit cannot parse is::quantized::value() within the + // EXPECT_TRUE() clause, so we delegate through a template function. + EXPECT_TRUE(GetQuantized()); + EXPECT_TRUE(GetQuantized()); + EXPECT_TRUE(GetQuantized()); + + EXPECT_FALSE(GetQuantized()); + EXPECT_FALSE(GetQuantized()); + EXPECT_FALSE(GetQuantized()); + EXPECT_FALSE(GetQuantized()); + + EXPECT_TRUE(DataTypeIsQuantized(DT_QINT8)); + EXPECT_TRUE(DataTypeIsQuantized(DT_QUINT8)); + EXPECT_TRUE(DataTypeIsQuantized(DT_QINT32)); + + EXPECT_FALSE(DataTypeIsQuantized(DT_INT8)); + EXPECT_FALSE(DataTypeIsQuantized(DT_UINT8)); + EXPECT_FALSE(DataTypeIsQuantized(DT_INT16)); + EXPECT_FALSE(DataTypeIsQuantized(DT_INT32)); + EXPECT_FALSE(DataTypeIsQuantized(DT_BFLOAT16)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc new file mode 100644 index 0000000000..fd79ead0b1 --- /dev/null +++ b/tensorflow/core/graph/algorithm.cc @@ -0,0 +1,107 @@ +#include "tensorflow/core/graph/algorithm.h" + +#include +#include +#include + +namespace tensorflow { + +void DFS(const Graph& g, std::function enter, + std::function leave) { + // Stack of work to do. + struct Work { + Node* node; + bool leave; // Are we entering or leaving n? 
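+    // Each node is pushed once with leave == false when it is discovered
+    // and, if a 'leave' callback was supplied, a second time with
+    // leave == true so the callback runs after all descendants.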
+ }; + std::vector stack; + stack.push_back(Work{g.source_node(), false}); + + std::vector visited(g.num_node_ids(), false); + while (!stack.empty()) { + Work w = stack.back(); + stack.pop_back(); + + Node* n = w.node; + if (w.leave) { + leave(n); + continue; + } + + if (visited[n->id()]) continue; + visited[n->id()] = true; + if (enter) enter(n); + + // Arrange to call leave(n) when all done with descendants. + if (leave) stack.push_back(Work{n, true}); + + // Arrange to work on descendants. + for (Node* out : n->out_nodes()) { + if (!visited[out->id()]) { + // Note; we must not mark as visited until we actually process it. + stack.push_back(Work{out, false}); + } + } + } +} + +void GetPostOrder(const Graph& g, std::vector* order) { + order->clear(); + DFS(g, nullptr, [order](Node* n) { order->push_back(n); }); +} + +void GetReversePostOrder(const Graph& g, std::vector* order) { + GetPostOrder(g, order); + std::reverse(order->begin(), order->end()); +} + +void PruneForReverseReachability(Graph* g, + const std::unordered_set& nodes) { + std::unordered_set visited; + + // Compute set of nodes that we need to traverse in order to reach + // the nodes in "nodes" by performing a breadth-first search from those + // nodes, and accumulating the visited nodes. + std::deque queue; + for (const Node* n : nodes) { + queue.push_back(n); + } + while (!queue.empty()) { + const Node* n = queue.front(); + queue.pop_front(); + if (visited.insert(n).second) { + for (const Node* in : n->in_nodes()) { + queue.push_back(in); + } + } + } + + // Make a pass over the graph to remove nodes not in "visited" + std::vector all_nodes; + for (Node* n : g->nodes()) { + all_nodes.push_back(n); + } + + for (Node* n : all_nodes) { + if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) { + g->RemoveNode(n); + } + } + + // Reconnect nodes with no outgoing edges to the sink node + FixupSourceAndSinkEdges(g); +} + +void FixupSourceAndSinkEdges(Graph* g) { + // Connect all nodes with no incoming edges to source. + // Connect all nodes with no outgoing edges to sink. + for (Node* n : g->nodes()) { + if (!n->IsSource() && n->in_edges().empty()) { + g->AddControlEdge(g->source_node(), n); + } + if (!n->IsSink() && n->out_edges().empty()) { + g->AddControlEdge(n, g->sink_node()); + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h new file mode 100644 index 0000000000..58b74a0ace --- /dev/null +++ b/tensorflow/core/graph/algorithm.h @@ -0,0 +1,40 @@ +#ifndef TENSORFLOW_GRAPH_ALGORITHM_H_ +#define TENSORFLOW_GRAPH_ALGORITHM_H_ + +#include +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Perform a depth-first-search on g starting at the source node. +// If enter is not empty, calls enter(n) before visiting any children of n. +// If leave is not empty, calls leave(n) after visiting all children of n. +extern void DFS(const Graph& g, std::function enter, + std::function leave); + +// Stores in *order the post-order numbering of all nodes +// in graph found via a depth first search starting at the source node. +// +// Note that this is equivalent to topological sorting when the +// graph does not have cycles. +// +// REQUIRES: order is not NULL. 
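+//
+// Illustrative usage, assuming an already-constructed Graph 'g':
+//
+//   std::vector<Node*> order;
+//   GetPostOrder(g, &order);
+//   // Every visited node appears after the nodes reachable from it;
+//   // reversing the vector (see GetReversePostOrder below) therefore gives
+//   // a topological order when 'g' is acyclic.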
+void GetPostOrder(const Graph& g, std::vector* order); + +// Stores in *order the reverse post-order numbering of all nodes +void GetReversePostOrder(const Graph& g, std::vector* order); + +// Prune nodes in "g" that are not in some path from the source node +// to any node in 'nodes'. +void PruneForReverseReachability(Graph* g, + const std::unordered_set& nodes); + +// Connect all nodes with no incoming edges to source. +// Connect all nodes with no outgoing edges to sink. +void FixupSourceAndSinkEdges(Graph* g); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_ALGORITHM_H_ diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc new file mode 100644 index 0000000000..48f2e1ebd7 --- /dev/null +++ b/tensorflow/core/graph/algorithm_test.cc @@ -0,0 +1,103 @@ +#include "tensorflow/core/graph/algorithm.h" + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/kernels/ops_util.h" +#include +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +// TODO(josh11b): Test setting the "device" field of a NodeDef. +// TODO(josh11b): Test that feeding won't prune targets. + +namespace tensorflow { +namespace { + +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); + +// Compares that the order of nodes in 'inputs' respects the +// pair orders described in 'ordered_pairs'. +bool ExpectBefore(const std::vector>& ordered_pairs, + const std::vector& inputs, string* error) { + for (const std::pair& pair : ordered_pairs) { + const string& before_node = pair.first; + const string& after_node = pair.second; + bool seen_before = false; + bool seen_both = false; + for (const Node* node : inputs) { + if (!seen_before && after_node == node->name()) { + *error = strings::StrCat("Saw ", after_node, " before ", before_node); + return false; + } + + if (before_node == node->name()) { + seen_before = true; + } else if (after_node == node->name()) { + seen_both = seen_before; + break; + } + } + if (!seen_both) { + *error = strings::StrCat("didn't see either ", before_node, " or ", + after_node); + return false; + } + } + + return true; +} + +TEST(AlgorithmTest, ReversePostOrder) { + RequireDefaultOps(); + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* w1 = SourceOp("TestParams", b.opts().WithName("W1")); + Node* w2 = SourceOp("TestParams", b.opts().WithName("W2")); + Node* input = + SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1)); + Node* t1 = BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t1")); + BinaryOp("TestMul", w1, {input, 1}, + b.opts().WithName("t2").WithControlInput(t1)); + BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3")); + + Graph g(OpRegistry::Global()); + ASSERT_OK(b.ToGraph(&g)); + std::vector order; + + // Test reverse post order: + GetReversePostOrder(g, &order); + + // Check that the order respects the dependencies correctly. 
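+  // For instance, "W1" must appear before "input", because the control edge
+  // W1 -> input added via WithControlInput() makes "input" reachable
+  // from "W1".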
+ std::vector> reverse_orders = { + {"W1", "input"}, {"W1", "t1"}, {"W1", "t2"}, {"W1", "t3"}, + {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}}; + string error; + EXPECT_TRUE(ExpectBefore(reverse_orders, order, &error)) << error; + + // A false ordering should fail the check. + reverse_orders = {{"input", "W1"}}; + EXPECT_FALSE(ExpectBefore(reverse_orders, order, &error)); + + // Test post order: + GetPostOrder(g, &order); + + // Check that the order respects the dependencies correctly. + std::vector> orders = { + {"input", "W1"}, {"t1", "W1"}, {"t2", "W1"}, {"t3", "W1"}, + {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}}; + EXPECT_TRUE(ExpectBefore(orders, order, &error)) << error; + + // A false ordering should fail the check. + orders = {{"W1", "t3"}}; + EXPECT_FALSE(ExpectBefore(orders, order, &error)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/colors.cc b/tensorflow/core/graph/colors.cc new file mode 100644 index 0000000000..0eb2fc3740 --- /dev/null +++ b/tensorflow/core/graph/colors.cc @@ -0,0 +1,25 @@ +#include "tensorflow/core/graph/colors.h" + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Color palette +// http://www.mulinblog.com/a-color-palette-optimized-for-data-visualization/ +static const char* kColors[] = { + "#F15854", // red + "#5DA5DA", // blue + "#FAA43A", // orange + "#60BD68", // green + "#F17CB0", // pink + "#B2912F", // brown + "#B276B2", // purple + "#DECF3F", // yellow + "#4D4D4D", // gray +}; + +const char* ColorFor(int dindex) { + return kColors[dindex % TF_ARRAYSIZE(kColors)]; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/colors.h b/tensorflow/core/graph/colors.h new file mode 100644 index 0000000000..150c8dc025 --- /dev/null +++ b/tensorflow/core/graph/colors.h @@ -0,0 +1,14 @@ +#ifndef TENSORFLOW_GRAPH_COLORS_H_ +#define TENSORFLOW_GRAPH_COLORS_H_ + +namespace tensorflow { + +// Return a color drawn from a palette to represent an entity +// identified by "i". The return value has the form "#RRGGBB" Note +// that the palette has a limited set of colors and therefore colors +// will be reused eventually. +const char* ColorFor(int dindex); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_COLORS_H_ diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc new file mode 100644 index 0000000000..89bc41acfd --- /dev/null +++ b/tensorflow/core/graph/costmodel.cc @@ -0,0 +1,308 @@ +#include "tensorflow/core/graph/costmodel.h" + +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace { +const Microseconds kDefaultTimeEstimate(1); +const Microseconds kMinTimeEstimate(1); +} // namespace + +void CostModel::SuppressInfrequent() { + // Find the median of the non-zero counts, and use half of its value + // as the cutoff for a "normal" execution mode node. 
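+  // For example, if the recorded counts are {0, 2, 8, 10, 40}, the non-zero
+  // counts are {2, 8, 10, 40}, the (upper) median selected below is 10, and
+  // min_count_ becomes 5, which TimeEstimate() and SizeEstimate() then use
+  // as the suppression cutoff.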
+ if (count_.empty()) return; + std::vector non_zero; + for (auto v : count_) { + if (v > 0) non_zero.push_back(v); + } + const size_t sz = non_zero.size(); + if (sz > 0) { + std::nth_element(non_zero.begin(), non_zero.begin() + sz / 2, + non_zero.end()); + int32 median_value = non_zero[sz / 2]; + min_count_ = median_value / 2; + VLOG(1) << "num non_zero vals: " << non_zero.size() << " median_value " + << median_value; + } else { + min_count_ = 1; + } +} + +void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) { + CHECK(is_global_); + CHECK(!cm.is_global()); + for (const Node* n : g.nodes()) { + const int local_id = cm.Id(n); + const int global_id = Id(n); + if (local_id < 0 || global_id < 0) continue; + Ensure(global_id); + count_[global_id] += cm.count_[local_id]; + time_[global_id] += cm.time_[local_id]; + int num_slots = cm.slot_bytes_[local_id].size(); + if (num_slots > 0) { + if (slot_bytes_[global_id].size() == 0) { + slot_bytes_[global_id].resize(num_slots); + } else { + CHECK_EQ(num_slots, slot_bytes_[global_id].size()); + } + for (int s = 0; s < num_slots; ++s) { + slot_bytes_[global_id][s] += cm.slot_bytes_[local_id][s]; + } + } + } +} + +void CostModel::MergeFromGlobal(const CostModel& cm) { + CHECK(is_global_); + CHECK_EQ(true, cm.is_global()); + const int num_nodes = cm.count_.size(); + Ensure(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + count_[i] += cm.count_[i]; + time_[i] += cm.time_[i]; + int num_slots = cm.slot_bytes_[i].size(); + if (num_slots > 0) { + if (slot_bytes_[i].size() == 0) { + slot_bytes_[i].resize(num_slots); + } else { + CHECK_EQ(num_slots, slot_bytes_[i].size()); + } + for (int s = 0; s < num_slots; ++s) { + slot_bytes_[i][s] += cm.slot_bytes_[i][s]; + } + } + } +} + +void CostModel::MergeFromStats(const NodeNameToCostIdMap& map, + const StepStats& ss) { + CHECK(is_global_); + for (auto& ds : ss.dev_stats()) { + for (auto& ns : ds.node_stats()) { + NodeNameToCostIdMap::const_iterator iter = map.find(ns.node_name()); + // We don't keep stats for nodes not in the global graph, i.e. + // copy/send/recv nodes, feed/fetch, etc. + if (iter == map.end()) continue; + int32 global_id = iter->second; + Ensure(global_id); + int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros(); + count_[global_id]++; + time_[global_id] += elapsed_micros; + for (auto& no : ns.output()) { + int si = no.slot(); + if (static_cast(si) >= slot_bytes_[global_id].size()) { + slot_bytes_[global_id].resize(1 + si); + } + slot_bytes_[global_id][si] += + no.tensor_description().allocation_description().requested_bytes(); + } + } + } +} + +void CostModel::Ensure(int id) { + if (slot_bytes_.size() <= static_cast(id)) { + slot_bytes_.resize(id + 1); + count_.resize(id + 1); + time_.resize(id + 1); + } +} + +void CostModel::SetNumOutputs(const Node* node, int num_outputs) { + const int id = Id(node); + if (id < 0) return; + Ensure(id); + auto perslot = &slot_bytes_[id]; + if (perslot->size() > 0) { + CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node=" + << node->name(); + } else { + perslot->resize(num_outputs, Bytes(-1)); + } +} + +void CostModel::RecordCount(const Node* node, int count) { + const int id = Id(node); + if (id < 0) return; + CHECK_LT(id, slot_bytes_.size()); + count_[id] += count; +} + +int32 CostModel::TotalCount(const Node* node) const { + const int id = Id(node); + if (id < 0) return 0; + return (static_cast(id) < slot_bytes_.size()) ? 
count_[id] : 0; +} + +void CostModel::RecordSize(const Node* node, int slot, Bytes bytes) { + const int id = Id(node); + if (id < 0) return; + CHECK_LT(id, slot_bytes_.size()); + auto perslot = &slot_bytes_[id]; + CHECK_LT(slot, perslot->size()); + auto v = &(*perslot)[slot]; + if (*v >= 0) { + *v += bytes; + } else { + *v = bytes; + } +} + +Bytes CostModel::TotalBytes(const Node* node, int slot) const { + const int id = Id(node); + if (id < 0 || static_cast(id) >= slot_bytes_.size() || + slot_bytes_[id].size() <= static_cast(slot)) { + return Bytes(0); + } + return slot_bytes_[id][slot]; +} + +Bytes CostModel::SizeEstimate(const Node* node, int slot) const { + int32 count = TotalCount(node); + if (count < min_count_) return Bytes(0); + return TotalBytes(node, slot) / std::max(1, TotalCount(node)); +} + +void CostModel::RecordTime(const Node* node, Microseconds time) { + const int id = Id(node); + if (id < 0) return; + DCHECK(node->IsOp()) << node->DebugString(); + Ensure(id); + time_[id] += time; +} + +Microseconds CostModel::TotalTime(const Node* node) const { + DCHECK(node->IsOp()) << node->DebugString(); + const int id = Id(node); + if (id < 0 || static_cast(id) >= time_.size() || + time_[id] < Microseconds(0)) { + return Microseconds(0); + } + return time_[id]; +} + +Microseconds CostModel::TimeEstimate(const Node* node) const { + int32 count = TotalCount(node); + if (count <= min_count_) return kMinTimeEstimate; + return std::max(kMinTimeEstimate, TotalTime(node) / std::max(1, count)); +} + +void CostModel::CheckInitialized(const Graph& graph) const { + for (const Node* n : graph.nodes()) { + if (n->IsOp()) { + CHECK(static_cast(n->id()) < time_.size() && + time_[n->id()] >= Microseconds(0)) + << ": no time estimate for " << n->DebugString(); + + CHECK(static_cast(n->id()) < slot_bytes_.size()) + << ": no size estimate for " << n->DebugString(); + const auto& perslot = slot_bytes_[n->id()]; + for (size_t i = 0; i < perslot.size(); i++) { + CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i + << " of " << n->DebugString(); + } + } + } +} + +Microseconds CostModel::CopyTimeEstimate(Bytes b, double network_latency_millis, + double estimated_gbps) { + // TODO(jeff,sanjay): estimate cost based on bandwidth along the + // communication path and the type of transport we are using between + // devices. + // + // We assume the copy time follows a linear model: + // copy_time = copy_bytes / rate + min_time + int64 copy_bytes = b.value(); + const double bytes_per_usec = estimated_gbps * 1000.0 / 8; + const double min_micros = network_latency_millis * 1000.0; + return Microseconds( + static_cast(copy_bytes / bytes_per_usec + min_micros)); +} + +Microseconds CostModel::ComputationTimeEstimate(int64 math_ops) { + // TODO(jeff,sanjay): Eventually we should pass in the type of device + // (GPU vs. CPU) and use that to affect the estimate. + + // We estimate the microseconds using that value. We divide + // by 1000 to convert the madd number into microseconds (assuming + // roughly 1000 madds per microsecond (~1 GHz for one core)). 
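+  // For example, math_ops == 2,000,000 (roughly a [1000 x 1000] * [1000 x 2]
+  // matmul) yields an estimate of 2000 microseconds under this model.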
+ return Microseconds(math_ops / 1000); +} + +// ---------------------------------------------------------------------------- +// InitCostModel +// ---------------------------------------------------------------------------- + +namespace { + +static void AddNodesToCostModel(const Graph& g, CostModel* cost_model) { + for (Node* n : g.nodes()) { + const int num_outputs = n->num_outputs(); + cost_model->SetNumOutputs(n, num_outputs); + for (int output = 0; output < num_outputs; output++) { + // Set up an initial bogus estimate for the node's outputs + cost_model->RecordSize(n, output, Bytes(1)); + } + } +} + +static void AssignSizes(const Graph& g, CostModel* cost_model) { + for (const Edge* e : g.edges()) { + // Skip if it is a control edge. + if (e->IsControlEdge()) { + continue; + } + Node* src = e->src(); + + // TODO(josh11b): Get an estimate from the Op + Bytes size(1); + cost_model->RecordSize(src, e->src_output(), size); + } +} + +// This generates an extremely simple initial guess for the +// computation cost of each node. For ordinary Ops, its value should quickly +// be wiped out by the real runtime measurements. For other Ops we don't +// actually generate measurements, so suppression of infrequent Ops ends up +// giving them 0 costs. So, this is not of much consequence except perhaps +// in tests. +static Microseconds TimeEstimateForNode(CostModel* cost_model, Node* n) { + CHECK(n->IsOp()); + VLOG(2) << "Node " << n->id() << ": " << n->name() + << " type_string: " << n->type_string(); + if (IsConstant(n) || IsVariable(n)) { + return Microseconds(0); + } + return kDefaultTimeEstimate; +} + +static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) { + for (Node* n : g.nodes()) { + if (!n->IsOp()) continue; + cost_model->RecordTime(n, TimeEstimateForNode(cost_model, n)); + } +} + +} // namespace + +void CostModel::InitFromGraph(const Graph& g) { + AddNodesToCostModel(g, this); + AssignSizes(g, this); + EstimateComputationCosts(g, this); + CheckInitialized(g); +} + +void CostModel::WriteToLog() { + LOG(INFO) << " min_count_=" << min_count_; + for (size_t i = 0; i < count_.size(); ++i) { + LOG(INFO) << "Node " << i << " count " << count_[i] << " total time " + << time_[i] << " avg time " + << (time_[i] / (std::max(1, count_[i]))); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h new file mode 100644 index 0000000000..4d7dd65f5a --- /dev/null +++ b/tensorflow/core/graph/costmodel.h @@ -0,0 +1,123 @@ +#ifndef TENSORFLOW_GRAPH_COSTMODEL_H_ +#define TENSORFLOW_GRAPH_COSTMODEL_H_ + +#include +#include + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { +typedef std::unordered_map NodeNameToCostIdMap; + +class StepStats; + +// CostModel keeps track of the following runtime statistics for nodes +// of a single Graph: +// * The total number of times a node has executed. +// * The accumulated execution time (in microseconds) of a node. +// * The accumulated size (in bytes) of each node's output. +// +// This class is NOT thread-safe. +class CostModel { + public: + // If "global" is true, maintains costs based on Node::cost_id, otherwise + // maintains costs based on Node::id. + explicit CostModel(bool is_global) : is_global_(is_global) {} + + // Assigns min_count_ as a function of the median count for a Node. 
+ // This value is then used for suppressing the time/size costs of + // infrequent operations. + // NOTE(tucker): Maybe this should move to a subclass of CostModel. + void SuppressInfrequent(); + + bool is_global() const { return is_global_; } + + // Initializes cost model for 'g'. + void InitFromGraph(const Graph& g); + + // Merges costs from cm. + // REQUIRES: is_global_ is true for this and for "cm" + void MergeFromGlobal(const CostModel& cm); + + // Merges costs from "cm", which has been computed relative to "g". + // REQUIRES: is_global_ is true for this, and false for "cm". + void MergeFromLocal(const Graph& g, const CostModel& cm); + + void MergeFromStats(const NodeNameToCostIdMap& map, const StepStats& ss); + + // Sets the number of outputs of "node". + void SetNumOutputs(const Node* node, int num_outputs); + + // Records that "node" has executed "num_count" more times. + void RecordCount(const Node* node, int num_count); + + // Returns how many times "node" has been executed. + int32 TotalCount(const Node* node) const; + + // Records that "output_slot" of "node" has produced tensors of + // aggregated "bytes". + void RecordSize(const Node* node, int output_slot, Bytes bytes); + + // Returns total bytes of tensors produced by "node"s output slot. + Bytes TotalBytes(const Node* node, int output_slot) const; + + // Returns a prediction for the size of the tensor at the + // output_slot produced by one execution of "node". + Bytes SizeEstimate(const Node* node, int output_slot) const; + + // Records that Executions of "node" have taken "time" microseconds. + void RecordTime(const Node* node, Microseconds time); + + // Returns the total execution time for "node". + Microseconds TotalTime(const Node* node) const; + + // Returns a prediction for one execution of "node". + Microseconds TimeEstimate(const Node* node) const; + + // Check that an estimate is available for every OP node in graph. + void CheckInitialized(const Graph& graph) const; + + // Helper routines to encapsulate static estimatation heuristics + + // Compute an estimate of the time to copy "b" bytes over the network, + // given a fixed cost of "network_latency_millis" milliseconds and + // an estimated bandwidth of "estimated_gbps" gigabits per second (note that + // this value is in gigabits, not gigabytes). + static Microseconds CopyTimeEstimate(Bytes b, double network_latency_millis, + double estimated_gbps); + static Microseconds ComputationTimeEstimate(int64 mathops); + + // Write the contents of the CostModel to the INFO log. + void WriteToLog(); + + private: + const bool is_global_; + inline int Id(const Node* n) const { + if (is_global_) { + return n->cost_id(); + } else { + return n->id(); + } + } + // Resizes vectors so that they are large enough for "id". + void Ensure(int id); + + // Nodes and Edges whose count is < this value + // get type/byte estimates of 0. + int32 min_count_ = 0; + + // Number of times each Node has been executed. + std::vector count_; + // Cumulative execution time. + std::vector time_; + // Cumulative Bytes output on each channel. 
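+  // Indexed first by node id (or cost id when is_global_ is true) and then
+  // by output slot.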
+ std::vector > slot_bytes_; + + TF_DISALLOW_COPY_AND_ASSIGN(CostModel); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_COSTMODEL_H_ diff --git a/tensorflow/core/graph/costutil.cc b/tensorflow/core/graph/costutil.cc new file mode 100644 index 0000000000..f8e2d9fe68 --- /dev/null +++ b/tensorflow/core/graph/costutil.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/graph/costutil.h" + +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/costmodel.h" + +namespace tensorflow { + +std::vector LongestOutgoingPathCost(const Graph& graph, + const CostModel& cm) { + std::vector result(graph.num_node_ids()); + DFS(graph, nullptr, [&result, &cm](Node* n) { + int64 max_child = 0; + for (const Node* out : n->out_nodes()) { + max_child = std::max(max_child, result[out->id()]); + } + result[n->id()] = max_child + (n->IsOp() ? cm.TimeEstimate(n).value() : 0); + }); + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/costutil.h b/tensorflow/core/graph/costutil.h new file mode 100644 index 0000000000..46e5215132 --- /dev/null +++ b/tensorflow/core/graph/costutil.h @@ -0,0 +1,19 @@ +#ifndef TENSORFLOW_GRAPH_COSTUTIL_H_ +#define TENSORFLOW_GRAPH_COSTUTIL_H_ + +#include +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class CostModel; +class Graph; + +// result[i] is an estimate of the longest execution path from +// the node with id i to the sink node. +std::vector LongestOutgoingPathCost(const Graph& graph, + const CostModel& cm); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_COSTUTIL_H_ diff --git a/tensorflow/core/graph/default_device.h b/tensorflow/core/graph/default_device.h new file mode 100644 index 0000000000..30cd4e8a57 --- /dev/null +++ b/tensorflow/core/graph/default_device.h @@ -0,0 +1,25 @@ +#ifndef TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_ +#define TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_ + +#include + +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { +namespace graph { + +// Sets the default device for all nodes in graph_def to "device", +// only if not already set. 
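+//
+// Illustrative usage ("/cpu:0" is just an example device name):
+//
+//   GraphDef gdef;
+//   // ... populate gdef ...
+//   graph::SetDefaultDevice("/cpu:0", &gdef);
+//   // Nodes that already named a device keep their original assignment.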
+inline void SetDefaultDevice(const string& device, GraphDef* graph_def) { + for (int i = 0; i < graph_def->node_size(); ++i) { + auto node = graph_def->mutable_node(i); + if (node->device().empty()) { + node->set_device(device); + } + } +} + +} // namespace graph +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_ diff --git a/tensorflow/core/graph/dot.cc b/tensorflow/core/graph/dot.cc new file mode 100644 index 0000000000..6d6e46ce61 --- /dev/null +++ b/tensorflow/core/graph/dot.cc @@ -0,0 +1,289 @@ +#include "tensorflow/core/graph/dot.h" + +#include +#include +#include + +#include "tensorflow/core/graph/colors.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +static string GraphNodeName(const DotOptions& opts, const Node* n) { + return strings::StrCat("N", n->id()); +} + +bool ShoulDisplayOpType(const Node* n) { + if (n->type_string() == "NoOp") { + return false; + } + const string& op_name = n->def().name(); + if (op_name.find(n->type_string() + "_") == 0) { + return false; + } + return true; +} + +string DotGraph(const Graph& g, const DotOptions& opts) { + RegexpStringPiece flag(opts.prefix_collapse_regexp); + if (flag == "all") { + flag = "."; + } else if (flag == "none") { + flag = "^$"; + } + RE2 cluster_name_pattern(flag); + string result; + strings::StrAppend(&result, "digraph G {\n"); + strings::StrAppend(&result, "rankdir=\"BT\"\n"); + + std::map device_index; // Map from device name to index. + std::unordered_set visible_nodes; // Nodes to display. + // Cluster name => set of nodes. + std::unordered_map > clusters; + // Node* => Cluster + std::unordered_map node_cluster; + for (Node* src : g.nodes()) { + if (opts.include_node_function != nullptr && + !opts.include_node_function(src)) { + continue; + } + // Do not display source and sink nodes + if (src->IsSource() || src->IsSink()) { + continue; + } + visible_nodes.insert(src); + const string name_prefix = NodeNamePrefix(src->def().name()).ToString(); + if (!name_prefix.empty()) { + clusters[name_prefix].insert(src); + node_cluster[src] = name_prefix; + } + // Record device if present. + if (src->IsOp()) { + const string& d = src->assigned_device_name(); + if (!d.empty()) { + device_index[d] = -1; // Assigned later + } + } + } + + // Add nodes whose name is exactly a cluster name to the cluster itself. + for (Node* src : g.nodes()) { + if (node_cluster.count(src) == 0) { + const string name = src->def().name(); + auto it = clusters.find(name); + if (it != clusters.end()) { + it->second.insert(src); + node_cluster[src] = name; + } + } + } + + auto node_in_collapsed_cluster = [&node_cluster, + &cluster_name_pattern](Node* n) { + return node_cluster.count(n) > 0 && + RE2::PartialMatch(node_cluster[n], cluster_name_pattern); + }; + + // Assign device indices in sorted order. 
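+  // (std::map iterates in key order, so e.g. "/cpu:0", "/gpu:0", "/gpu:1"
+  // would get indices 0, 1 and 2 respectively.)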
+ int num = 0; + for (auto& e : device_index) { + e.second = num++; + } + + double total_node_cost = 0; + double avg_node_cost = 1; + if (opts.node_cost) { + int node_count = 0; + for (const Node* n : g.nodes()) { + total_node_cost += opts.node_cost(n); + ++node_count; + } + if (total_node_cost > 0) avg_node_cost = total_node_cost / node_count; + } + + for (Node* src : g.nodes()) { + if (visible_nodes.count(src) == 0 || node_in_collapsed_cluster(src)) { + continue; + } + string label = src->name(); + if (ShoulDisplayOpType(src)) { + // Append the op type if it is not directly deducible from the op name. + strings::StrAppend(&label, "\\n(", src->type_string(), ")"); + } + const char* shape = "box"; + const char* color = nullptr; + if (src->IsSource()) { + shape = "oval"; + } else if (src->IsSink()) { + shape = "oval"; + } else { + const string& d = src->assigned_device_name(); + const int dindex = (!d.empty()) ? device_index[d] : -1; + if (dindex >= 0) { + color = ColorFor(dindex); + } + + shape = "box"; + } + + if (opts.node_label) { + string extra = opts.node_label(src); + if (!extra.empty()) { + strings::StrAppend(&label, "\\n", extra); + } + } + + strings::StrAppend(&result, GraphNodeName(opts, src), "[shape=", shape, + ", label=\"", label, "\""); + if (opts.node_cost && total_node_cost > 0) { + // Pick fontsize in range [8..40] so that area is proportional to cost. + const double cost = opts.node_cost(src); + const double relcost = fabs(cost / avg_node_cost); + // Average cost node has font size of 12. + const int fs = 8 + static_cast(4.0 * std::min(sqrt(relcost), 8.0)); + strings::StrAppend(&result, ", width=0, height=0, fontsize=", fs); + VLOG(2) << "Node: " << cost << " => " << relcost << " => " << fs; + } + if (color != nullptr) { + strings::StrAppend(&result, ", fillcolor=\"", color, + "\", fontcolor=\"white\", style=\"filled\""); + } + strings::StrAppend(&result, "]\n"); + } + + for (auto c : clusters) { + const string& cluster_name = c.first; + const std::unordered_set nodes = c.second; + std::unordered_map node_colors; + for (auto n : nodes) { + const string& d = n->assigned_device_name(); + const int dindex = (!d.empty()) ? 
device_index[d] : -1; + if (dindex >= 0) { + ++node_colors[ColorFor(dindex)]; + } + } + + string majority_color; + if (node_colors.empty()) { + majority_color = ColorFor(0); + } else { + majority_color = std::max_element(node_colors.begin(), node_colors.end(), + [](const std::pair& x, + const std::pair& y) { + return x.second < y.second; + }) + ->first; + } + + if (!RE2::PartialMatch(cluster_name, cluster_name_pattern)) { + strings::StrAppend(&result, "subgraph cluster_", cluster_name, "{\n"); + for (auto n : nodes) { + strings::StrAppend(&result, GraphNodeName(opts, n), ";\n"); + } + strings::StrAppend(&result, "}\n"); + } else { + strings::StrAppend(&result, cluster_name, " [shape=oval, fillcolor=\"", + majority_color, "\", label=\"", cluster_name, + "\", style=\"filled\", fontcolor=\"white\"]\n"); + } + } + + std::unordered_set edge_drawn; + + double max_edge_cost = 0; + double total_edge_cost = 0; + double avg_edge_cost = 1; + if (opts.edge_cost && g.edges().size()) { + for (const Edge* e : g.edges()) { + auto cost = opts.edge_cost(e); + total_edge_cost += cost; + max_edge_cost = std::max(max_edge_cost, cost); + } + avg_edge_cost = total_edge_cost / g.edges().size(); + } + VLOG(2) << "Edge cost tot/max/avg: " << total_edge_cost << "/" + << max_edge_cost << "/" << avg_edge_cost; + + for (const Edge* e : g.edges()) { + Node* src = e->src(); + Node* dst = e->dst(); + // If either endpoint isn't drawn in the graph, don't draw the edge + if (visible_nodes.count(src) == 0 || visible_nodes.count(dst) == 0) { + continue; + } + + const string src_name = node_in_collapsed_cluster(src) + ? node_cluster[src] + : GraphNodeName(opts, src); + const string dst_name = node_in_collapsed_cluster(dst) + ? node_cluster[dst] + : GraphNodeName(opts, dst); + // Don't draw self edges + if (src_name == dst_name) { + continue; + } + // And previously drawn edges. + const string& edge_name = strings::StrCat(src_name, ":", dst_name); + if (edge_drawn.count(edge_name) > 0) { + continue; + } + edge_drawn.insert(edge_name); + + strings::StrAppend(&result, src_name, " -> ", dst_name, "["); + string label; + if (e->IsControlEdge()) { + strings::StrAppend(&result, " style=dotted"); + } + if (opts.edge_label) { + string label = opts.edge_label(e); + if (!label.empty()) { + strings::StrAppend(&result, " label=<", label, ">"); + } + } + // Make edge widths proportional to amount of data transferred. + if (opts.edge_cost && max_edge_cost > 0) { + const double cost = opts.edge_cost(e); + const double relcost = fabs(cost / avg_edge_cost); + // Pick penwidth in range [1..6] so that width is proportional to cost. + const int pw = 1 + std::min(5, static_cast(2.0 * relcost)); + strings::StrAppend(&result, " penwidth=", pw); + // Use weight attributes [1..100] to keep heavier edges more vertical. + const int weight = 1 + std::min(99, static_cast(100.0 * relcost)); + strings::StrAppend(&result, " weight=", weight); + VLOG(2) << "Edge: " << cost << " => " << relcost << " => " << pw << "/" + << weight; + } + + strings::StrAppend(&result, "]\n"); + } + // Compute some statistics + int op_nodes = 0; + for (Node* n : g.nodes()) { + if (n->IsOp()) { + op_nodes++; + } + } + + // Emit legend + strings::StrAppend(&result, + "{ rank = source; Legend [shape=box, margin=0, label=<", + "", "\n"); + for (const auto& e : device_index) { + const int dindex = e.second; + strings::StrAppend(&result, "\n"); + } + strings::StrAppend(&result, "
op_nodes: ", + op_nodes, "
", dindex, "", + e.first, "
>]}\n"); + + strings::StrAppend(&result, "}\n"); // End digraph + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/dot.h b/tensorflow/core/graph/dot.h new file mode 100644 index 0000000000..f87f68099c --- /dev/null +++ b/tensorflow/core/graph/dot.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_GRAPH_DOT_H_ +#define TENSORFLOW_GRAPH_DOT_H_ + +#include +#include +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class Edge; +class Graph; +class Node; + +struct DotOptions { + bool (*include_node_function)(const Node*) = nullptr; + + // By default, all nodes with the same name prefix are collapsed into + // a single node in the dot graph. This regexp can be changed so that + // only prefixes that match the regexp are collapsed in this fashion. + // 'all' collapses all ops with prefixes, 'none' disables all collapsing. + string prefix_collapse_regexp = "all"; + + // A function that returns a label to embed into the per-node display. + std::function node_label; + + // A function that returns a label to attach to an edge. + std::function edge_label; + + // A function that returns the "cost" of the node. The dot display + // makes a node size proportional to its cost. + std::function node_cost; + + // A function that returns the "cost" of the edge. The dot display + // makes a edge thickness proportional to its cost. + std::function edge_cost; +}; + +// Return a string that contains a graphviz specification of the graph. +string DotGraph(const Graph& g, const DotOptions& opts); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_DOT_H_ diff --git a/tensorflow/core/graph/edgeset.cc b/tensorflow/core/graph/edgeset.cc new file mode 100644 index 0000000000..83293c7b4e --- /dev/null +++ b/tensorflow/core/graph/edgeset.cc @@ -0,0 +1,56 @@ +#include "tensorflow/core/graph/edgeset.h" + +namespace tensorflow { + +std::pair EdgeSet::insert(value_type value) { + RegisterMutation(); + const_iterator ci; + ci.Init(this); + auto s = get_set(); + if (!s) { + for (int i = 0; i < kInline; i++) { + if (ptrs_[i] == value) { + ci.array_iter_ = &ptrs_[i]; + return std::make_pair(ci, false); + } + } + for (int i = 0; i < kInline; i++) { + if (ptrs_[i] == nullptr) { + ptrs_[i] = value; + ci.array_iter_ = &ptrs_[i]; + return std::make_pair(ci, true); + } + } + // array is full. convert to set. + s = new std::set; + for (int i = 0; i < kInline; i++) { + s->insert(static_cast(ptrs_[i])); + } + ptrs_[0] = this; + ptrs_[1] = s; + // fall through. + } + auto p = s->insert(value); + ci.tree_iter_ = p.first; + return std::make_pair(ci, p.second); +} + +EdgeSet::size_type EdgeSet::erase(key_type key) { + RegisterMutation(); + auto s = get_set(); + if (!s) { + for (int i = 0; i < kInline; i++) { + if (ptrs_[i] == key) { + size_t n = size(); + ptrs_[i] = ptrs_[n - 1]; + ptrs_[n - 1] = nullptr; + return 1; + } + } + return 0; + } else { + return s->erase(key); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/edgeset.h b/tensorflow/core/graph/edgeset.h new file mode 100644 index 0000000000..df0d78b8fb --- /dev/null +++ b/tensorflow/core/graph/edgeset.h @@ -0,0 +1,216 @@ +#ifndef TENSORFLOW_GRAPH_EDGESET_H_ +#define TENSORFLOW_GRAPH_EDGESET_H_ + +#include +#include +#include "tensorflow/core/platform/port.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +class Edge; + +// An unordered set of edges. Uses very little memory for small sets. +// Unlike std::set, EdgeSet does NOT allow mutations during iteration. 
+class EdgeSet { + public: + EdgeSet(); + ~EdgeSet(); + + typedef const Edge* key_type; + typedef const Edge* value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + class const_iterator; + typedef const_iterator iterator; + + bool empty() const; + size_type size() const; + void clear(); + std::pair insert(value_type value); + size_type erase(key_type key); + + // Caller is not allowed to mutate the EdgeSet while iterating. + const_iterator begin() const; + const_iterator end() const; + + private: + // Up to kInline elements are stored directly in ptrs_ (nullptr means none). + // If ptrs_[0] == this then ptrs_[1] points to a set. + static const int kInline = 2; // Must be >= 2. + const void* ptrs_[kInline]; + + std::set* get_set() const { + if (ptrs_[0] == this) { + return static_cast*>(const_cast(ptrs_[1])); + } else { + return nullptr; + } + } + +// To detect mutations while iterating. +#ifdef NDEBUG + void RegisterMutation() {} +#else + uint32 mutations_ = 0; + void RegisterMutation() { mutations_++; } +#endif + + TF_DISALLOW_COPY_AND_ASSIGN(EdgeSet); +}; + +class EdgeSet::const_iterator { + public: + typedef typename EdgeSet::value_type value_type; + typedef const typename EdgeSet::value_type& reference; + typedef const typename EdgeSet::value_type* pointer; + typedef typename EdgeSet::difference_type difference_type; + typedef std::forward_iterator_tag iterator_category; + + const_iterator() {} + + const_iterator& operator++(); + const_iterator operator++(int /*unused*/); + const value_type* operator->() const; + value_type operator*() const; + bool operator==(const const_iterator& other) const; + bool operator!=(const const_iterator& other) const { + return !(*this == other); + } + + private: + friend class EdgeSet; + + void const* const* array_iter_ = nullptr; + typename std::set::const_iterator tree_iter_; + +#ifdef NDEBUG + inline void Init(const EdgeSet* e) {} + inline void CheckNoMutations() const {} +#else + inline void Init(const EdgeSet* e) { + owner_ = e; + init_mutations_ = e->mutations_; + } + inline void CheckNoMutations() const { + CHECK_EQ(init_mutations_, owner_->mutations_); + } + const EdgeSet* owner_ = nullptr; + uint32 init_mutations_ = 0; +#endif +}; + +inline EdgeSet::EdgeSet() { + for (int i = 0; i < kInline; i++) { + ptrs_[i] = nullptr; + } +} + +inline EdgeSet::~EdgeSet() { delete get_set(); } + +inline bool EdgeSet::empty() const { return size() == 0; } + +inline EdgeSet::size_type EdgeSet::size() const { + auto s = get_set(); + if (s) { + return s->size(); + } else { + size_t result = 0; + for (int i = 0; i < kInline; i++) { + if (ptrs_[i]) result++; + } + return result; + } +} + +inline void EdgeSet::clear() { + RegisterMutation(); + delete get_set(); + for (int i = 0; i < kInline; i++) { + ptrs_[i] = nullptr; + } +} + +inline EdgeSet::const_iterator EdgeSet::begin() const { + const_iterator ci; + ci.Init(this); + auto s = get_set(); + if (s) { + ci.tree_iter_ = s->begin(); + } else { + ci.array_iter_ = &ptrs_[0]; + } + return ci; +} + +inline EdgeSet::const_iterator EdgeSet::end() const { + const_iterator ci; + ci.Init(this); + auto s = get_set(); + if (s) { + ci.tree_iter_ = s->end(); + } else { + ci.array_iter_ = &ptrs_[size()]; + } + return ci; +} + +inline EdgeSet::const_iterator& EdgeSet::const_iterator::operator++() { + CheckNoMutations(); + if (array_iter_ != nullptr) { + ++array_iter_; + } else { + ++tree_iter_; + } + return *this; +} + +inline EdgeSet::const_iterator EdgeSet::const_iterator::operator++( + int /*unused*/) { 
+ CheckNoMutations(); + const_iterator tmp = *this; + operator++(); + return tmp; +} + +// gcc's set and multiset always use const_iterator since it will otherwise +// allow modification of keys. +inline const EdgeSet::const_iterator::value_type* EdgeSet::const_iterator:: +operator->() const { + CheckNoMutations(); + if (array_iter_ != nullptr) { + return reinterpret_cast(array_iter_); + } else { + return tree_iter_.operator->(); + } +} + +// gcc's set and multiset always use const_iterator since it will otherwise +// allow modification of keys. +inline EdgeSet::const_iterator::value_type EdgeSet::const_iterator::operator*() + const { + CheckNoMutations(); + if (array_iter_ != nullptr) { + return static_cast(*array_iter_); + } else { + return *tree_iter_; + } +} + +inline bool EdgeSet::const_iterator::operator==( + const const_iterator& other) const { + DCHECK((array_iter_ == nullptr) == (other.array_iter_ == nullptr)) + << "Iterators being compared must be from same set that has not " + << "been modified since the iterator was constructed"; + CheckNoMutations(); + if (array_iter_ != nullptr) { + return array_iter_ == other.array_iter_; + } else { + return other.array_iter_ == nullptr && tree_iter_ == other.tree_iter_; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_EDGESET_H_ diff --git a/tensorflow/core/graph/edgeset_test.cc b/tensorflow/core/graph/edgeset_test.cc new file mode 100644 index 0000000000..7909e8ea0a --- /dev/null +++ b/tensorflow/core/graph/edgeset_test.cc @@ -0,0 +1,95 @@ +#include "tensorflow/core/graph/edgeset.h" + +#include "tensorflow/core/graph/graph.h" +#include + +namespace tensorflow { +class EdgeSetTest : public ::testing::Test { + public: + EdgeSetTest() : edges_(nullptr), eset_(nullptr) {} + + ~EdgeSetTest() override { + delete eset_; + delete[] edges_; + } + + void MakeEdgeSet(int n) { + delete eset_; + delete[] edges_; + edges_ = new Edge[n]; + eset_ = new EdgeSet; + model_.clear(); + for (int i = 0; i < n; i++) { + eset_->insert(&edges_[i]); + model_.insert(&edges_[i]); + } + } + + void CheckSame() { + EXPECT_EQ(model_.size(), eset_->size()); + EXPECT_EQ(model_.empty(), eset_->empty()); + std::vector modelv(model_.begin(), model_.end()); + std::vector esetv(eset_->begin(), eset_->end()); + std::sort(modelv.begin(), modelv.end()); + std::sort(esetv.begin(), esetv.end()); + EXPECT_EQ(modelv.size(), esetv.size()); + for (size_t i = 0; i < modelv.size(); i++) { + EXPECT_EQ(modelv[i], esetv[i]) << i; + } + } + + Edge nonexistent_; + Edge* edges_; + EdgeSet* eset_; + std::set model_; +}; + +namespace { + +TEST_F(EdgeSetTest, Ops) { + for (int n : {0, 1, 2, 3, 4, 10}) { + MakeEdgeSet(n); + CheckSame(); + EXPECT_EQ((n == 0), eset_->empty()); + EXPECT_EQ(n, eset_->size()); + + eset_->clear(); + model_.clear(); + CheckSame(); + + eset_->insert(&edges_[0]); + model_.insert(&edges_[0]); + CheckSame(); + } +} + +// Try insert/erase of existing elements at different positions. +TEST_F(EdgeSetTest, Exists) { + for (int n : {0, 1, 2, 3, 4, 10}) { + MakeEdgeSet(n); + for (int pos = 0; pos < n; pos++) { + MakeEdgeSet(n); + auto p = eset_->insert(&edges_[pos]); + EXPECT_FALSE(p.second); + EXPECT_EQ(&edges_[pos], *p.first); + + EXPECT_EQ(1, eset_->erase(&edges_[pos])); + model_.erase(&edges_[pos]); + CheckSame(); + } + } +} + +// Try insert/erase of non-existent element. 
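// kInline in EdgeSet is 2, so the n == 3 case above is the first one that
// exercises the conversion from the inline array to the heap-allocated
// std::set.  A focused regression test for that boundary could look like the
// following sketch (hypothetical test name, same fixture):
//
//   TEST_F(EdgeSetTest, OverflowBoundary) {
//     MakeEdgeSet(3);   // crossing kInline forces the inline -> set switch
//     CheckSame();
//     EXPECT_EQ(1, eset_->erase(&edges_[2]));
//     model_.erase(&edges_[2]);
//     CheckSame();
//   }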
+TEST_F(EdgeSetTest, DoesNotExist) { + for (int n : {0, 1, 2, 3, 4, 10}) { + MakeEdgeSet(n); + EXPECT_EQ(0, eset_->erase(&nonexistent_)); + auto p = eset_->insert(&nonexistent_); + EXPECT_TRUE(p.second); + EXPECT_EQ(&nonexistent_, *p.first); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/equal_graph_def.cc b/tensorflow/core/graph/equal_graph_def.cc new file mode 100644 index 0000000000..35f59b5ed0 --- /dev/null +++ b/tensorflow/core/graph/equal_graph_def.cc @@ -0,0 +1,176 @@ +#include "tensorflow/core/graph/equal_graph_def.h" + +#include +#include +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, + string* diff) { + std::unordered_map actual_index; + for (const NodeDef& node : actual.node()) { + actual_index[node.name()] = &node; + } + + for (const NodeDef& expected_node : expected.node()) { + auto actual_iter = actual_index.find(expected_node.name()); + if (actual_iter == actual_index.end()) { + if (diff != nullptr) { + *diff = strings::StrCat("Did not find expected node '", + SummarizeNodeDef(expected_node), "'"); + } + return false; + } + + if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false; + + actual_index.erase(actual_iter); + } + + if (!actual_index.empty()) { + if (diff != nullptr) { + *diff = strings::StrCat("Found unexpected node '", + SummarizeNodeDef(*actual_index.begin()->second), + "' not in expected graph:\n", + SummarizeGraphDef(expected)); + } + return false; + } + + return true; +} + +namespace { + +string JoinStringField(const protobuf::RepeatedPtrField& f) { + string ret; + for (int i = 0; i < f.size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, f.Get(i)); + } + return ret; +} + +} // namespace + +bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, + string* diff) { + if (actual.name() != expected.name()) { + if (diff != nullptr) { + *diff = strings::StrCat("Actual node name '", actual.name(), + "' is not expected '", expected.name(), "'"); + } + return false; + } + + if (actual.op() != expected.op()) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), "' has op '", + actual.op(), "' that is not expected '", + expected.op(), "'"); + } + return false; + } + + if (actual.device() != expected.device()) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), "' has device '", + actual.device(), "' that is not expected '", + expected.device(), "'"); + } + return false; + } + + if (actual.input_size() != expected.input_size()) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '", + JoinStringField(actual.input()), + "' that don't match expected '", + JoinStringField(expected.input()), "'"); + } + return false; + } + + int first_control_input = actual.input_size(); + for (int i = 0; i < actual.input_size(); ++i) { + if (StringPiece(actual.input(i)).starts_with("^")) { + first_control_input = i; + break; + } + if (actual.input(i) != expected.input(i)) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), "' has input ", + i, " '", actual.input(i), + "' that doesn't match expected '", + expected.input(i), "'"); + } + return false; + } + } + + std::unordered_set actual_control; + 
std::unordered_set expected_control; + for (int i = first_control_input; i < actual.input_size(); ++i) { + actual_control.insert(actual.input(i)); + expected_control.insert(expected.input(i)); + } + for (const auto& e : expected_control) { + if (actual_control.erase(e) == 0) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), + "' missing expected control input '", e, "'"); + } + return false; + } + } + if (!actual_control.empty()) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), + "' has unexpected control input '", + *actual_control.begin(), "'"); + } + return false; + } + + std::unordered_set actual_attr; + for (const auto& a : actual.attr()) { + actual_attr.insert(a.first); + } + for (const auto& e : expected.attr()) { + if (actual_attr.erase(e.first) == 0) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), + "' missing expected attr '", e.first, + "' with value: ", SummarizeAttrValue(e.second)); + } + return false; + } + auto iter = actual.attr().find(e.first); + if (!AreAttrValuesEqual(e.second, iter->second)) { + if (diff != nullptr) { + *diff = strings::StrCat( + "Node named '", actual.name(), "' has attr '", e.first, + "' with value: ", SummarizeAttrValue(iter->second), + " that does not match expected: ", SummarizeAttrValue(e.second)); + } + return false; + } + } + if (!actual_attr.empty()) { + if (diff != nullptr) { + *diff = strings::StrCat( + "Node named '", actual.name(), "' has unexpected attr '", + *actual_attr.begin(), "' with value: ", + SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second)); + } + return false; + } + + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/equal_graph_def.h b/tensorflow/core/graph/equal_graph_def.h new file mode 100644 index 0000000000..7dd8aab340 --- /dev/null +++ b/tensorflow/core/graph/equal_graph_def.h @@ -0,0 +1,32 @@ +#ifndef TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_ +#define TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Determines if actual and expected are equal, ignoring ordering of +// nodes, attrs, and control inputs. If the GraphDefs are different +// and diff != nullptr, *diff is set to an explanation of the +// difference. Note that we use node names to match up nodes between +// the graphs, and so the naming of nodes must be consistent. +bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, + string* diff); + +// Determines if actual and expected are equal, ignoring ordering of +// attrs and control inputs. If the NodeDefs are different and +// diff != nullptr, *diff is set to an explanation of the difference. 
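// A typical call site (illustrative only; `actual` and `expected` stand for
// GraphDefs built elsewhere in a test):
//
//   string diff;
//   if (!EqualGraphDef(actual, expected, &diff)) {
//     LOG(ERROR) << "GraphDefs differ: " << diff;
//   }
//   // or, inside a unit test, via the macro defined below:
//   TF_EXPECT_GRAPH_EQ(expected, actual);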
+bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff); + +#define TF_EXPECT_GRAPH_EQ(expected, actual) \ + do { \ + string diff; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_ diff --git a/tensorflow/core/graph/equal_graph_def_test.cc b/tensorflow/core/graph/equal_graph_def_test.cc new file mode 100644 index 0000000000..3a38b9e522 --- /dev/null +++ b/tensorflow/core/graph/equal_graph_def_test.cc @@ -0,0 +1,279 @@ +#include "tensorflow/core/graph/equal_graph_def.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include + +namespace tensorflow { +namespace { + +REGISTER_OP("Input").Output("o: float"); +REGISTER_OP("Alternate").Output("o: float"); +REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float"); + +Node* Input(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("Input", opts); +} + +Node* Alternate(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("Alternate", opts); +} + +Node* Cross(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("Cross", a, b, opts); +} + +class EqualGraphDefTest : public ::testing::Test { + protected: + EqualGraphDefTest() + : e_(GraphDefBuilder::kFailImmediately), + a_(GraphDefBuilder::kFailImmediately) { + RequireDefaultOps(); + } + + bool Match() { + GraphDef expected; + e_.ToGraphDef(&expected); + GraphDef actual; + a_.ToGraphDef(&actual); + return EqualGraphDef(actual, expected, &diff_); + } + + GraphDefBuilder e_; + GraphDefBuilder a_; + string diff_; +}; + +TEST_F(EqualGraphDefTest, Match) { + Input(e_.opts().WithName("A")); + Input(a_.opts().WithName("A")); + EXPECT_TRUE(Match()) << diff_; +} + +TEST_F(EqualGraphDefTest, NoMatch) { + Input(e_.opts().WithName("A")); + Input(a_.opts().WithName("B")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Did not find expected node 'A = Input[]()'", diff_); +} + +TEST_F(EqualGraphDefTest, MissingNode) { + Input(e_.opts().WithName("A")); + Input(e_.opts().WithName("B")); + Input(a_.opts().WithName("A")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Did not find expected node 'B = Input[]()'", diff_); +} + +TEST_F(EqualGraphDefTest, ExtraNode) { + Input(e_.opts().WithName("A")); + Input(a_.opts().WithName("A")); + Input(a_.opts().WithName("B")); + EXPECT_FALSE(Match()); + EXPECT_EQ( + "Found unexpected node 'B = Input[]()' not in expected graph:\n" + "A = Input[]();\n", + diff_); +} + +TEST_F(EqualGraphDefTest, NodeOrder) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Cross(a, b, e_.opts().WithName("C")); + + b = Input(a_.opts().WithName("B")); + a = Input(a_.opts().WithName("A")); + Cross(a, b, a_.opts().WithName("C")); + EXPECT_TRUE(Match()) << diff_; +} + +TEST_F(EqualGraphDefTest, NameMismatch) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + // Have to call EqualNodeDef() directly here, since EqualGraphDef() + // only calls EqualNodeDef() with nodes that have matching names. 
+ EXPECT_FALSE(EqualNodeDef(a->def(), b->def(), &diff_)); + EXPECT_EQ("Actual node name 'A' is not expected 'B'", diff_); +} + +TEST_F(EqualGraphDefTest, OpMismatch) { + Input(e_.opts().WithName("A")); + Alternate(a_.opts().WithName("A")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'A' has op 'Alternate' that is not expected 'Input'", + diff_); +} + +TEST_F(EqualGraphDefTest, DeviceMatch) { + Input(e_.opts().WithName("A").WithDevice("/cpu:0")); + Input(a_.opts().WithName("A").WithDevice("/cpu:0")); + EXPECT_TRUE(Match()) << diff_; +} + +TEST_F(EqualGraphDefTest, DeviceMismatch) { + Input(e_.opts().WithName("A").WithDevice("/cpu:0")); + Input(a_.opts().WithName("A").WithDevice("/cpu:1")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'A' has device '/cpu:1' that is not expected '/cpu:0'", + diff_); +} + +TEST_F(EqualGraphDefTest, InputMismatch) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Cross(a, a, e_.opts().WithName("C")); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + Cross(b, b, a_.opts().WithName("C")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'", + diff_); +} + +TEST_F(EqualGraphDefTest, InputOrderMismatch) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Cross(a, b, e_.opts().WithName("C")); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + Cross(b, a, a_.opts().WithName("C")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'", + diff_); +} + +TEST_F(EqualGraphDefTest, ControlInputOrder) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Node* c = Input(e_.opts().WithName("C")); + Node* d = Input(e_.opts().WithName("D")); + Cross(a, a, e_.opts() + .WithName("E") + .WithControlInput(b) + .WithControlInput(c) + .WithControlInput(d)); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + c = Input(a_.opts().WithName("C")); + d = Input(a_.opts().WithName("D")); + Cross(a, a, a_.opts() + .WithName("E") + .WithControlInput(c) + .WithControlInput(d) + .WithControlInput(b)); + EXPECT_TRUE(Match()) << diff_; +} + +TEST_F(EqualGraphDefTest, ControlInputMismatch) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Node* c = Input(e_.opts().WithName("C")); + Node* d = Input(e_.opts().WithName("D")); + Cross(a, a, e_.opts().WithName("E").WithControlInput(b).WithControlInput(c)); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + c = Input(a_.opts().WithName("C")); + d = Input(a_.opts().WithName("D")); + Cross(a, a, a_.opts().WithName("E").WithControlInput(b).WithControlInput(d)); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'E' missing expected control input '^C'", diff_); +} + +TEST_F(EqualGraphDefTest, ControlInputAdded) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Node* c = Input(e_.opts().WithName("C")); + Cross(a, a, e_.opts().WithName("D").WithControlInput(b)); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + c = Input(a_.opts().WithName("C")); + Cross(a, a, a_.opts().WithName("D").WithControlInput(b).WithControlInput(c)); + EXPECT_FALSE(Match()); + EXPECT_EQ( + "Node named 'D' has inputs 'A, A, ^B, ^C' that don't match " + "expected 'A, A, ^B'", + diff_); +} + +TEST_F(EqualGraphDefTest, ControlInputRemoved) 
{ + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Node* c = Input(e_.opts().WithName("C")); + Cross(a, a, e_.opts().WithName("D").WithControlInput(b).WithControlInput(c)); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + c = Input(a_.opts().WithName("C")); + Cross(a, a, a_.opts().WithName("D").WithControlInput(b)); + EXPECT_FALSE(Match()); + EXPECT_EQ( + "Node named 'D' has inputs 'A, A, ^B' that don't match " + "expected 'A, A, ^B, ^C'", + diff_); +} + +TEST_F(EqualGraphDefTest, Attr) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef same(a->def()); + AddNodeAttr("foo", "bar", &same); + EXPECT_TRUE(EqualNodeDef(same, same, &diff_)) << diff_; +} + +TEST_F(EqualGraphDefTest, AttrAdded) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef actual(a->def()); + AddNodeAttr("foo", "bar", &actual); + EXPECT_FALSE(EqualNodeDef(actual, a->def(), &diff_)); + EXPECT_EQ("Node named 'A' has unexpected attr 'foo' with value: \"bar\"", + diff_); +} + +TEST_F(EqualGraphDefTest, AttrRemoved) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef expected(a->def()); + AddNodeAttr("foo", "bar", &expected); + EXPECT_FALSE(EqualNodeDef(a->def(), expected, &diff_)); + EXPECT_EQ("Node named 'A' missing expected attr 'foo' with value: \"bar\"", + diff_); +} + +TEST_F(EqualGraphDefTest, AttrOrder) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef actual(a->def()); + AddNodeAttr("foo", "bar", &actual); + AddNodeAttr("baz", 42, &actual); + + NodeDef expected(a->def()); + AddNodeAttr("baz", 42, &expected); + AddNodeAttr("foo", "bar", &expected); + + EXPECT_TRUE(EqualNodeDef(actual, expected, &diff_)) << diff_; +} + +TEST_F(EqualGraphDefTest, AttrMismatch) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef actual(a->def()); + AddNodeAttr("foo", "bar", &actual); + AddNodeAttr("baz", 5, &actual); + + NodeDef expected(a->def()); + AddNodeAttr("baz", 42, &expected); + AddNodeAttr("foo", "bar", &expected); + + EXPECT_FALSE(EqualNodeDef(actual, expected, &diff_)); + EXPECT_EQ( + "Node named 'A' has attr 'baz' with value: 5 that does not match " + "expected: 42", + diff_); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc new file mode 100644 index 0000000000..0c268a51a9 --- /dev/null +++ b/tensorflow/core/graph/graph.cc @@ -0,0 +1,319 @@ +#include "tensorflow/core/graph/graph.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// Node + +string Node::DebugString() const { + if (this == nullptr) { + return "{nullptr}"; + } + string ret = strings::StrCat("{name:'", name(), "' id:", id_); + if (IsSource()) { + strings::StrAppend(&ret, " source}"); + } else if (IsSink()) { + strings::StrAppend(&ret, " sink}"); + } else { + strings::StrAppend(&ret, " op device:"); + strings::StrAppend(&ret, "{", assigned_device_name_, "}"); + strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}"); + } + return ret; +} + +Node::Node() + : id_(-1), cost_id_(-1), props_(nullptr), assigned_device_name_() {} + +Node::~Node() { + if (props_) { + props_->Unref(); + } +} + +void Node::Initialize(int id, int cost_id, Properties* props) { + DCHECK_EQ(id_, -1); + DCHECK(in_edges_.empty()); + 
DCHECK(out_edges_.empty()); + id_ = id; + cost_id_ = cost_id; + + // Unref the old, assign the new properties. + if (props_) { + props_->Unref(); + } + props_ = props; +} + +void Node::Clear() { + in_edges_.clear(); + out_edges_.clear(); + id_ = -1; + cost_id_ = -1; + + if (props_) { + props_->Unref(); + props_ = nullptr; + } + + assigned_device_name_.clear(); +} + +gtl::iterator_range Node::out_nodes() const { + return gtl::make_range(NeighborIter(out_edges_.begin(), false), + NeighborIter(out_edges_.end(), false)); +} + +gtl::iterator_range Node::in_nodes() const { + return gtl::make_range(NeighborIter(in_edges_.begin(), true), + NeighborIter(in_edges_.end(), true)); +} + +// Node::Properties + +Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def, + const DataTypeSlice inputs, + const DataTypeSlice outputs) + : op_def_(op_def), + node_def_(node_def), + input_types_(inputs.begin(), inputs.end()), + output_types_(outputs.begin(), outputs.end()) {} + +Node::Properties::~Properties() {} + +// Graph + +Graph::Graph(const OpRegistryInterface* ops) + : ops_(ops), arena_(8 << 10 /* 8kB */) { + // Source and sink have no endpoints, just control edges. + NodeDef def; + def.set_name("_SOURCE"); + def.set_op("NoOp"); + Status status; + Node* source = AddNode(def, &status); + TF_CHECK_OK(status); + CHECK_EQ(source->id(), kSourceId); + + def.set_name("_SINK"); + Node* sink = AddNode(def, &status); + TF_CHECK_OK(status); + CHECK_EQ(sink->id(), kSinkId); + + AddControlEdge(source, sink); +} + +Graph::~Graph() { + // Manually call the destructors for all the Nodes we constructed using + // placement new. + for (Node* node : nodes_) { + if (node != nullptr) { + node->~Node(); + } + } + for (Node* node : free_nodes_) { + node->~Node(); + } + // Edges have no destructor, and we arena-allocated them, so no need to + // destroy them. +} + +Node* Graph::AddNode(const NodeDef& node_def, Status* status) { + const OpDef* op_def = ops_->LookUp(node_def.op(), status); + if (op_def == nullptr) return nullptr; + + // TODO(vrv,josh11b): Find a location higher in the stack to add these defaults + // to the NodeDef. + NodeDef node_def_with_defaults(node_def); + AddDefaultsToNodeDef(*op_def, &node_def_with_defaults); + + DataTypeVector inputs; + DataTypeVector outputs; + status->Update( + InOutTypesForNode(node_def_with_defaults, *op_def, &inputs, &outputs)); + if (!status->ok()) { + *status = AttachDef(*status, node_def_with_defaults); + return nullptr; + } + + Node* node = AllocateNode( + new Node::Properties(op_def, node_def_with_defaults, inputs, outputs), + nullptr); + return node; +} + +Node* Graph::CopyNode(Node* node) { + DCHECK(!node->IsSource()); + DCHECK(!node->IsSink()); + Node::Properties* props = node->properties(); + props->Ref(); + Node* copy = AllocateNode(props, node); + copy->set_assigned_device_name(node->assigned_device_name()); + return copy; +} + +void Graph::RemoveNode(Node* node) { + DCHECK(IsValidNode(node)) << node->DebugString(); + DCHECK(!node->IsSource()); + DCHECK(!node->IsSink()); + + // Remove any edges involving this node. 
+ while (!node->in_edges_.empty()) { + RemoveEdge(*node->in_edges_.begin()); + } + while (!node->out_edges_.empty()) { + RemoveEdge(*node->out_edges_.begin()); + } + ReleaseNode(node); +} + +const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) { + DCHECK(IsValidNode(source)) << source->DebugString(); + DCHECK(IsValidNode(dest)) << dest->DebugString(); + + // source/sink must only be linked via control slots, and + // control slots must only be linked to control slots. + if (source == source_node() || dest == sink_node() || x == kControlSlot || + y == kControlSlot) { + DCHECK_EQ(x, kControlSlot) << source->DebugString(); + DCHECK_EQ(y, kControlSlot) << dest->DebugString(); + } + + Edge* e = nullptr; + if (free_edges_.empty()) { + e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new + } else { + e = free_edges_.back(); + free_edges_.pop_back(); + } + e->id_ = edges_.size(); + e->src_ = source; + e->dst_ = dest; + e->src_output_ = x; + e->dst_input_ = y; + CHECK(source->out_edges_.insert(e).second); + CHECK(dest->in_edges_.insert(e).second); + edges_.push_back(e); + edge_set_.insert(e); + return e; +} + +void Graph::RemoveEdge(const Edge* e) { + DCHECK(IsValidNode(e->src_)) << e->src_->DebugString(); + DCHECK(IsValidNode(e->dst_)) << e->dst_->DebugString(); + CHECK_EQ(e->src_->out_edges_.erase(e), 1); + CHECK_EQ(e->dst_->in_edges_.erase(e), 1); + CHECK_EQ(e, edges_[e->id_]); + + CHECK_EQ(edge_set_.erase(e), 1); + edges_[e->id_] = nullptr; + + Edge* del = const_cast(e); + del->src_ = nullptr; + del->dst_ = nullptr; + del->id_ = -1; + del->src_output_ = kControlSlot - 1; + del->dst_input_ = kControlSlot - 1; + free_edges_.push_back(del); +} + +namespace { + +void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { + if (src_slot == Graph::kControlSlot) { + dst->add_input(strings::StrCat("^", src_name)); + } else if (src_slot == 0) { + dst->add_input(src_name.data(), src_name.size()); + } else { + dst->add_input(strings::StrCat(src_name, ":", src_slot)); + } +} + +} // namespace + +void Graph::ToGraphDef(GraphDef* graph_def) const { + graph_def->Clear(); + std::vector + inputs; // Construct this outside the loop for speed. + for (const Node* node : nodes()) { + if (!node->IsOp()) continue; + NodeDef* node_def = graph_def->add_node(); + *node_def = node->def(); + + // Use the node's assigned device, if any, instead of the device requested + // in the NodeDef. + if (!node->assigned_device_name().empty()) { + node_def->set_device(node->assigned_device_name()); + } + + // Get the inputs for this Node. We make sure control inputs are + // after data inputs, as required by GraphDef. + inputs.clear(); + inputs.resize(node->num_inputs(), nullptr); + for (const Edge* edge : node->in_edges()) { + if (edge->IsControlEdge()) { + inputs.push_back(edge); + } else { + DCHECK(inputs[edge->dst_input()] == nullptr); + inputs[edge->dst_input()] = edge; + } + } + node_def->clear_input(); + for (size_t i = 0; i < inputs.size(); ++i) { + const Edge* edge = inputs[i]; + if (edge == nullptr) { + node_def->add_input(node->def().input(i)); + } else { + const Node* src = edge->src(); + if (!src->IsOp()) continue; + AddInput(node_def, src->name(), edge->src_output()); + } + } + } +} + +string Graph::NewName(StringPiece prefix) { + return strings::StrCat(prefix, "/_", name_counter_++); +} + +gtl::iterator_range Graph::nodes() const { + // Note that NodeId 0 is always valid since we don't let the source + // node be removed from the graph. 
+ return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids())); +} + +bool Graph::IsValidNode(Node* node) const { + if (node == nullptr) return false; + const int id = node->id(); + if (id < 0 || static_cast(id) >= nodes_.size()) return false; + return nodes_[id] == node; +} + +Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) { + Node* node = nullptr; + if (free_nodes_.empty()) { + node = new (arena_.Alloc(sizeof(Node))) Node; // placement new + } else { + node = free_nodes_.back(); + free_nodes_.pop_back(); + } + const int id = nodes_.size(); + int cost_id = cost_node ? cost_node->cost_id() : id; + node->Initialize(id, cost_id, props); + nodes_.push_back(node); + return node; +} + +void Graph::ReleaseNode(Node* node) { + DCHECK(IsValidNode(node)) << node->DebugString(); + nodes_[node->id()] = nullptr; + free_nodes_.push_back(node); + node->Clear(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h new file mode 100644 index 0000000000..030e471bf4 --- /dev/null +++ b/tensorflow/core/graph/graph.h @@ -0,0 +1,440 @@ +// A Graph describes a set of computations that are to be +// performed, as well as the dependencies between those +// compuations. The basic model is a DAG (directed acyclic graph) with +// * internal nodes representing computational operations to be performed; +// * edges represent dependencies, indicating the target may only be +// executed once the source has completed; and +// * predefined "source" (start) and "sink" (finish) nodes -- the source +// should be the only node that doesn't depend on anything, and the sink +// should be the only node that nothing depends on. +// +// Note: Node ids are intended to be relatively dense in the +// 0..max_id range, but there may be gaps since ids won't be reused. +// +// Note: Some dependencies between operations are due to one operation +// consuming the output of another. In fact operations can produce +// multiple outputs and consume multiple inputs, and some +// optimizations will care about which specific outputs are connected +// to which specific inputs. We therefore represent data dependency +// between output O of layer A and input I of layer B using +// "input index" and "output index" labels per edge. 
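// As a concrete illustration of the "input index"/"output index" labelling
// described above (a sketch only; a_def and c_def stand for NodeDefs of
// registered ops, and error handling is elided):
//
//   Graph g(OpRegistry::Global());
//   Status s;
//   Node* a = g.AddNode(a_def, &s);
//   Node* c = g.AddNode(c_def, &s);
//   g.AddEdge(a, /*x=*/0, c, /*y=*/1);   // output 0 of a -> input 1 of c
//   g.AddControlEdge(a, c);              // pure ordering dependency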
+ +#ifndef TENSORFLOW_GRAPH_GRAPH_H_ +#define TENSORFLOW_GRAPH_GRAPH_H_ + +#include +#include +#include +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/edgeset.h" +#include "tensorflow/core/lib/core/arena.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/iterator_range.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class Edge; +class EdgeSetTest; +class Graph; +class Node; + +class NeighborIter; // Declared below +class NodeIter; // Declared below + +class Node { + public: + string DebugString() const; + int id() const { return id_; } + int cost_id() const { return cost_id_; } + const string& name() const { return props_->node_def_.name(); } + const string& type_string() const { return props_->node_def_.op(); } + const NodeDef& def() const { return props_->node_def_; } + const OpDef& op_def() const { return *props_->op_def_; } + + // input and output types + int num_inputs() const { return props_->input_types_.size(); } + DataType input_type(int i) const { return props_->input_types_[i]; } + const DataTypeVector& input_types() const { return props_->input_types_; } + + int num_outputs() const { return props_->output_types_.size(); } + DataType output_type(int o) const { return props_->output_types_[o]; } + const DataTypeVector& output_types() const { return props_->output_types_; } + + // This gives the device the runtime has assigned this node to. If + // you want the device the user requested, use def().device() instead. + // TODO(josh11b): Validate that the assigned_device, if not empty: + // fully specifies a device, and satisfies def().device(). + // TODO(josh11b): Move device_name outside of Node into a NodeId->DeviceName + // map. + string assigned_device_name() const { return assigned_device_name_; } + void set_assigned_device_name(const string& device_name) { + assigned_device_name_ = device_name; + } + + // Get the neighboring nodes via edges either in or out of this node. + gtl::iterator_range in_nodes() const; + gtl::iterator_range out_nodes() const; + const EdgeSet& in_edges() const { return in_edges_; } + const EdgeSet& out_edges() const { return out_edges_; } + + // Node type helpers. + bool IsSource() const { return id() == 0; } + bool IsSink() const { return id() == 1; } + // Anything other than the special Source & Sink nodes. + bool IsOp() const { return id() > 1; } + + private: + friend class Graph; + Node(); + ~Node(); + + class Properties : public core::RefCounted { + public: + Properties(const OpDef* op_def, const NodeDef& node_def, + const DataTypeSlice inputs, const DataTypeSlice outputs); + + const OpDef* op_def_; // not owned + const NodeDef node_def_; + const DataTypeVector input_types_; + const DataTypeVector output_types_; + + private: + // Destructor invoked when last reference goes away via Unref() + virtual ~Properties(); + TF_DISALLOW_COPY_AND_ASSIGN(Properties); + }; + + Properties* properties() const { return props_; } + + // Initialize() adopts a reference to props, and so is suitable if props was + // just allocated or you call props->Ref() to increment the reference + // count for a props being held by another Node. + void Initialize(int id, int cost_id, Properties* props); + // Releases memory from props_, in addition to restoring *this to its + // uninitialized state. 
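// A sketch of inspecting a node's fan-in with the public accessors above
// (illustrative only; `n` stands for a Node* obtained from a Graph):
//
//   for (const Edge* e : n->in_edges()) {
//     if (e->IsControlEdge()) continue;
//     VLOG(2) << e->src()->name() << ":" << e->src_output() << " -> input "
//             << e->dst_input() << " ("
//             << DataTypeString(n->input_type(e->dst_input())) << ")";
//   }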
+ void Clear(); + + int id_; // -1 until Initialize() is called + int cost_id_; // -1 if there is no corresponding cost accounting node + + EdgeSet in_edges_; + EdgeSet out_edges_; + + Properties* props_; + + // Name of device assigned to perform this computation. + string assigned_device_name_; + + TF_DISALLOW_COPY_AND_ASSIGN(Node); +}; + +class Edge { + public: + Node* src() const { return src_; } + Node* dst() const { return dst_; } + int id() const { return id_; } + + // Return the number of the source output that produces the data + // carried by this edge. The special value kControlSlot is used + // for control dependencies. + int src_output() const { return src_output_; } + + // Return the number of the destination input that consumes the data + // carried by this edge. The special value kControlSlot is used + // for control dependencies. + int dst_input() const { return dst_input_; } + + // Return true iff this is an edge that indicates a control-flow + // (as opposed to a data-flow) dependency. + bool IsControlEdge() const; + + private: + Edge() {} + + friend class EdgeSetTest; + friend class Graph; + Node* src_; + Node* dst_; + int id_; + int src_output_; + int dst_input_; +}; + +// Thread compatible but not thread safe. +class Graph { + public: + // Constructs a graph with a single SOURCE (always id kSourceId) and a + // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. + // + // The graph can hold ops found in registry. + explicit Graph(const OpRegistryInterface* registry); + ~Graph(); + + static const int kControlSlot = -1; + + // Adds a new node to this graph, and returns it. Infers the Op and + // input/output types for the node. *this owns the returned instance. + // Returns nullptr and sets *status on error. + Node* AddNode(const NodeDef& node_def, Status* status); + + // Copies *node, which may belong to another graph, to a new node, + // which is returned. Does not copy any edges. *this owns the + // returned instance. + Node* CopyNode(Node* node); + + // Remove a node from this graph, including all edges from or to it. + // *node should not be accessed after calling this function. + // REQUIRES: node->IsOp() + void RemoveNode(Node* node); + + // Add an edge that connects the xth output of "source" to the yth input + // of "dest". + const Edge* AddEdge(Node* source, int x, Node* dest, int y); + + // Add a control-edge (no data flows along this edge) that + // connects "source" to "dest". + const Edge* AddControlEdge(Node* source, Node* dest) { + return AddEdge(source, kControlSlot, dest, kControlSlot); + } + + // Removes edge from the graph. + // REQUIRES: The edge must exist. + void RemoveEdge(const Edge* edge); + + // Returns one more than the maximum id assigned to any node. + int num_node_ids() const { return nodes_.size(); } + + // Serialize to a GraphDef. + void ToGraphDef(GraphDef* graph_def) const; + + // Generate new node name with the specified prefix that is unique + // across this graph. + string NewName(StringPiece prefix); + + // Access to the list of all nodes. Example usage: + // for (Node* node : graph.nodes()) { ... } + gtl::iterator_range nodes() const; + + // Returns the node associated with an id, or nullptr if no node + // with that id (the node with that id was removed and the id has + // not yet been re-used). *this owns the returned instance. + // REQUIRES: 0 <= id < num_node_ids(). + Node* FindNodeId(int id) const { return nodes_[id]; } + + // Returns one more than the maximum id assigned to any edge. 
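// Because ids are never reused, code that wants a stable dense index can
// iterate by id and skip the holes left by removed nodes (a sketch; `g` is
// some Graph):
//
//   for (int id = 0; id < g.num_node_ids(); ++id) {
//     Node* n = g.FindNodeId(id);
//     if (n == nullptr) continue;   // id belonged to a removed node
//     VLOG(2) << n->DebugString();
//   }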
+ int num_edge_ids() const { return edges_.size(); } + + // Returns the Edge associated with an id, or nullptr if no edge + // with that id (the node with that id was removed and the id has + // not yet been re-used). *this owns the returned instance. + // REQUIRES: 0 <= id < num_node_ids(). + const Edge* FindEdgeId(int id) const { return edges_[id]; } + + // Access to the set of all edges. Example usage: + // for (const Edge* e : graph.edges()) { ... } + const EdgeSet& edges() const { return edge_set_; } + + // The pre-defined nodes. + enum { kSourceId = 0, kSinkId = 1 }; + Node* source_node() const { return FindNodeId(kSourceId); } + Node* sink_node() const { return FindNodeId(kSinkId); } + + const OpRegistryInterface* op_registry() const { return ops_; } + + // TODO(josh11b): uint64 hash() const; + + private: + bool IsValidNode(Node* node) const; + // If cost_node is non-null, then cost accounting (in CostModel) + // will be associated with that node rather than the new one being + // created. + Node* AllocateNode(Node::Properties* props, const Node* cost_node); + void ReleaseNode(Node* node); + + // Registry of all known ops. Not owned. + const OpRegistryInterface* const ops_; + + // Allocator which will give us good locality. + core::Arena arena_; + + // Map from node ids to allocated nodes. nodes_[id] may be nullptr if + // the node with that id was removed from the graph. + std::vector nodes_; + + // Map from edge ids to allocated edges. edges_[id] may be nullptr if + // the edge with that id was removed from the graph. + std::vector edges_; + + // For ease of iteration, we currently just keep a set of all live + // edges. May want to optimize by removing this copy. + EdgeSet edge_set_; + + // Allocated but free nodes and edges. + std::vector free_nodes_; + std::vector free_edges_; + + // For generating unique names. + int name_counter_ = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(Graph); +}; + +// TODO(josh11b): We may want to support keeping an index on various +// node/edge attributes in a graph, particularly node names. + +// Helper routines + +inline bool IsSwitch(const Node* node) { + return node->type_string() == "Switch" || node->type_string() == "RefSwitch"; +} + +inline bool IsMerge(const Node* node) { return node->type_string() == "Merge"; } + +inline bool IsEnter(const Node* node) { + return node->type_string() == "Enter" || node->type_string() == "RefEnter"; +} + +inline bool IsExit(const Node* node) { return node->type_string() == "Exit"; } + +inline bool IsNextIteration(const Node* node) { + return node->type_string() == "NextIteration"; +} + +inline bool IsLoopCond(const Node* node) { + return node->type_string() == "LoopCond"; +} + +inline bool IsControlTrigger(const Node* node) { + return node->type_string() == "ControlTrigger"; +} + +inline bool IsSend(const Node* node) { + return node->type_string() == "_Send" || node->type_string() == "_HostSend"; +} + +inline bool IsRecv(const Node* node) { + return node->type_string() == "_Recv" || node->type_string() == "_HostRecv"; +} + +// True for Nodes that mediate the transfer of values between processes. 
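// These helpers are plain string matches on type_string(); a typical use is
// filtering while walking a graph (a sketch; `g` is some Graph):
//
//   int special = 0;
//   for (Node* n : g.nodes()) {
//     if (IsControlFlow(n) || IsTransferNode(n)) ++special;
//   }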
+inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); } + +inline bool IsConstant(const Node* node) { + return node->type_string() == "Const" || node->type_string() == "HostConst"; +} + +inline bool IsVariable(const Node* node) { + return node->type_string() == "Variable"; +} + +inline bool IsIdentity(const Node* node) { + return (node->type_string() == "Identity" || + node->type_string() == "RefIdentity"); +} + +// Returns true iff 'n' is a control flow node. +inline bool IsControlFlow(const Node* n) { + return IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n) || + IsNextIteration(n); +} + +inline bool IsHostMemoryPreserving(const Node* node) { + return IsIdentity(node) || IsControlFlow(node); +} + +// Iterator for stepping through the nodes of a graph. +class NodeIter { + public: + NodeIter(const Graph* graph, int id); + bool operator==(const NodeIter& rhs); + bool operator!=(const NodeIter& rhs); + void operator++(); + Node* operator*(); + Node* operator->(); + + private: + // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr + const Graph* graph_; + int id_; +}; + +// Iterator for stepping through the neighbors of a node. +class NeighborIter { + public: + NeighborIter(EdgeSet::const_iterator iter, bool incoming); + bool operator==(const NeighborIter& rhs); + bool operator!=(const NeighborIter& rhs); + void operator++(); + Node* operator*(); + Node* operator->(); + + private: + EdgeSet::const_iterator iter_; + bool incoming_; +}; + +// IMPLEMENTATION DETAILS, PLEASE IGNORE + +inline NodeIter::NodeIter(const Graph* graph, int id) + : graph_(graph), id_(id) {} + +inline bool NodeIter::operator==(const NodeIter& rhs) { + DCHECK(graph_ == rhs.graph_); + return id_ == rhs.id_; +} + +inline bool NodeIter::operator!=(const NodeIter& rhs) { + return !(*this == rhs); +} + +inline void NodeIter::operator++() { + while (1) { + DCHECK_LE(id_, graph_->num_node_ids()); + ++id_; + if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) { + return; + } + } +} + +inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); } + +inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); } + +inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming) + : iter_(iter), incoming_(incoming) {} + +inline bool NeighborIter::operator==(const NeighborIter& rhs) { + return iter_ == rhs.iter_ && incoming_ == rhs.incoming_; +} + +inline bool NeighborIter::operator!=(const NeighborIter& rhs) { + return !(*this == rhs); +} + +inline void NeighborIter::operator++() { ++iter_; } + +inline Node* NeighborIter::operator*() { + const Edge* e = *iter_; + return incoming_ ? e->src() : e->dst(); +} + +inline Node* NeighborIter::operator->() { + const Edge* e = *iter_; + return incoming_ ? e->src() : e->dst(); +} + +inline bool Edge::IsControlEdge() const { + // Note that if either src_output_ or dst_input_ is kControlSlot, + // so is the other one (AddEdge checks this). 
+ return src_output_ == Graph::kControlSlot; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_H_ diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc new file mode 100644 index 0000000000..3928348f0a --- /dev/null +++ b/tensorflow/core/graph/graph_constructor.cc @@ -0,0 +1,385 @@ +#include "tensorflow/core/graph/graph_constructor.h" + +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/optimizer_cse.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +namespace { +inline bool IsMerge(const NodeDef& node_def) { + return node_def.op() == "Merge"; +} +} // namespace + +namespace { + +class GraphConstructor { + public: + GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef, + Graph* g, Status* status) + : opts_(opts), gdef_(gdef), g_(g), status_(status) { + BuildNodeIndex(); + InitFromEdges(); + Convert(); + } + + private: + void SetError(const string& error); + void SetNodeError(const NodeDef& node_def, const StringPiece& message) { + SetError(strings::StrCat("Node '", node_def.name(), "': ", message)); + } + void BuildNodeIndex(); + void InitFromEdges(); + Node* MakeNode(const NodeDef& node_def); + void Convert(); + // Calls SetError() and returns false if the type of the output of + // the source of the edge can't be consumed by destination of the edge. + // REQUIRES: edge must be a data edge, not a control edge. + bool TypeValidateEdge(const Edge* edge); + + // From constructor + const GraphConstructorOptions opts_; + const GraphDef* gdef_; + Graph* g_; + Status* status_; + + // Mapping from node name to the index within gdef_ + struct NodeInfo { + explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {} + // std::unordered_map<> requires that we have a default constructor. + NodeInfo() : NodeInfo(-1) {} + int gdef_index; + Node* node; // nullptr until the NodeDef is converted to a Node. + }; + // TODO(vrv): Profile this data structure to see if we should use an + // alternative implementation of std::unordered_map. + std::unordered_map name_index_; + + // Index of NodeDefs in gdef_ with all inputs already converted. + std::vector ready_; + + // Mapping between index within gdef_ and the number of inputs that + // still need to be converted. + std::vector pending_count_; + + // Mapping between index within gdef_ and the index within gdef_ of + // all nodes it outputs to. + std::vector> outputs_; + + // Used in the conversion from gdef_ to g_ to represent the ith input + // of a node. + struct InputInfo { + explicit InputInfo(StringPiece node_name, Node* n, int i) + : name(node_name), node(n), index(i) {} + StringPiece name; + Node* node; + int index; + }; + + // Used in the conversion from gdef_ to g_ to represent an edge from + // the node named 'name' to node 'n'. 
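// pending_count_ and outputs_ drive a Kahn-style topological conversion: a
// NodeDef becomes "ready" once all of its already-converted inputs are done,
// and a Merge node waits for only one data input (plus its control inputs),
// which is what lets while-loop back edges through without deadlocking the
// walk.  Worked example (a sketch): for a GraphDef with edges A->C, B->C and
// C->D, the initial state is pending_count_ = {A:0, B:0, C:2, D:1} and
// ready_ = {A, B}; converting A and B decrements C to zero, and converting C
// then readies D.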
+ struct EdgeInfo { + explicit EdgeInfo(StringPiece name, int i1, Node* n, int i2) + : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {} + StringPiece src_name; + int src_index; + Node* dst_node; + int dst_index; + }; +}; + +void GraphConstructor::SetError(const string& error) { + status_->Update(errors::InvalidArgument(error)); +} + +void GraphConstructor::BuildNodeIndex() { + // Initialized outside the loop for efficiency + const char* pattern; + if (opts_.allow_internal_ops) { + pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*"; + } else { + pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"; + } + RE2 node_name_re(pattern); + + // Validate the node names and add them to name_index_. + for (int n = 0; n < gdef_->node_size(); ++n) { + const NodeDef& node_def(gdef_->node(n)); + if (!RE2::FullMatch(node_def.name(), node_name_re)) { + SetNodeError(node_def, "Node name contains invalid characters"); + return; + } + if (!name_index_.insert(std::make_pair(StringPiece(node_def.name()), + NodeInfo(n))) + .second) { + SetNodeError(node_def, "Node name is not unique"); + return; + } + // Validate the operation's type. + if (node_def.op().empty()) { + SetNodeError(node_def, "Does not specify a type"); + return; + } + if (opts_.expect_device_spec && node_def.device().empty()) { + SetNodeError(node_def, strings::StrCat("Missing device specification.")); + return; + } + } +} + +void GraphConstructor::InitFromEdges() { + const int num_nodes = gdef_->node_size(); + ready_.reserve(num_nodes); + pending_count_.reserve(num_nodes); + outputs_.resize(num_nodes); + + // Parse the inputs for each node. + for (int n = 0; n < num_nodes; ++n) { + const NodeDef& node_def(gdef_->node(n)); + if (IsMerge(node_def)) { + // for merge only wait for one non-control input. + int32 num_control_edges = 0; + for (int i = 0; i < node_def.input_size(); ++i) { + StringPiece input_name(node_def.input(i)); + if (StringPiece(input_name).starts_with("^")) { + num_control_edges++; + } + } + pending_count_.push_back(num_control_edges + 1); + } else { + pending_count_.push_back(node_def.input_size()); + } + if (node_def.input_size() == 0) { + ready_.push_back(n); + continue; + } + for (int i = 0; i < node_def.input_size(); ++i) { + StringPiece input_name = node_def.input(i); + if (input_name.starts_with("^")) { + // Control dependence + input_name.remove_prefix(1); + } + TensorId id(ParseTensorName(input_name)); + auto iter = name_index_.find(id.first); + if (iter == name_index_.end()) { + SetNodeError(node_def, + strings::StrCat("Unknown input node ", node_def.input(i))); + return; + } + outputs_[iter->second.gdef_index].push_back(n); + } + } +} + +Node* GraphConstructor::MakeNode(const NodeDef& node_def) { + // Add the node to the graph. + Node* node = g_->AddNode(node_def, status_); + if (node == nullptr) return nullptr; + if (opts_.expect_device_spec) { + node->set_assigned_device_name(node_def.device()); + } + name_index_[node_def.name()].node = node; + return node; +} + +// Return the number of nodes in "g" +static int CountNodes(Graph* g) { + int nodes = 0; + for (Node* node : g->nodes()) { + VLOG(1) << node; // Dummy use to avoid compiler warning + nodes++; + } + return nodes; +} + +void GraphConstructor::Convert() { + std::vector inputs; + std::vector back_edges; + int processed = 0; + // Process the NodeDefs in topological order. 
+ while (!ready_.empty()) { + int o = ready_.back(); + ready_.pop_back(); + ++processed; + const NodeDef& node_def(gdef_->node(o)); + inputs.clear(); + bool in_control_dependence = false; + bool has_data_back_edge = false; + for (int i = 0; i < node_def.input_size(); ++i) { + StringPiece input_name(node_def.input(i)); + if (StringPiece(input_name).starts_with("^")) { + // A control dependence + in_control_dependence = true; + input_name.remove_prefix(1); + } else { + if (in_control_dependence) { + SetNodeError(node_def, strings::StrCat( + "Control dependencies must come after ", + "regular dependencies: input ", input_name, + " of source node ", node_def.name())); + return; + } + } + TensorId id(ParseTensorName(input_name)); + auto iter = name_index_.find(id.first); + DCHECK(iter != name_index_.end()); + Node* src_node = iter->second.node; + if (in_control_dependence) { + inputs.push_back(InputInfo(id.first, src_node, -1)); + } else { + if (src_node == nullptr) { + has_data_back_edge = true; + inputs.push_back(InputInfo(id.first, src_node, id.second)); + } else { + if (id.second >= src_node->num_outputs()) { + SetNodeError( + node_def, + strings::StrCat("Connecting to invalid output ", id.second, + " of source node ", id.first, " which has ", + src_node->num_outputs(), " outputs")); + return; + } + inputs.push_back(InputInfo(id.first, src_node, id.second)); + } + } + } + if (has_data_back_edge && !IsMerge(node_def)) { + SetError(strings::StrCat( + node_def.name(), + " had a back edge. But only Merge can have back edges.")); + return; + } + + Node* node = MakeNode(node_def); + if (node == nullptr) return; + + // Add edges from inputs to *node to the graph. + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i].node == nullptr) { + // Record this back edge, which will be added after all nodes + // are created. + back_edges.push_back( + EdgeInfo(inputs[i].name, inputs[i].index, node, i)); + } else if (inputs[i].index == -1) { + g_->AddControlEdge(inputs[i].node, node); + } else { + const Edge* edge = + g_->AddEdge(inputs[i].node, inputs[i].index, node, i); + if (!TypeValidateEdge(edge)) return; + } + } + + // Update pending_count_ for outputs. + for (size_t i = 0; i < outputs_[o].size(); ++i) { + const int output = outputs_[o][i]; + pending_count_[output]--; + if (pending_count_[output] == 0) { + ready_.push_back(output); + } + } + } + + // Add the back edges after all nodes are created. + for (auto e : back_edges) { + Node* src_node = name_index_[e.src_name].node; + if (e.src_index == -1) { + g_->AddControlEdge(src_node, e.dst_node); + } else { + const Edge* edge = + g_->AddEdge(src_node, e.src_index, e.dst_node, e.dst_index); + if (!TypeValidateEdge(edge)) return; + } + + VLOG(2) << "Add back edge: " << src_node->name() << " -> " + << e.dst_node->name(); + } + + if (processed < gdef_->node_size()) { + SetError( + strings::StrCat(gdef_->node_size() - processed, " nodes in a cycle")); + return; + } + + if (status_->ok()) { + FixupSourceAndSinkEdges(g_); + + if (opts_.optimizer_do_cse) { + if (!back_edges.empty()) { + LOG(WARNING) << "Not doing CSE. 
We need to figure out how to handle " + << "loops in the CSE phase."; + } else { + VLOG(1) << "Starting CSE: graph of " << CountNodes(g_) << " nodes"; + OptimizeCSE(g_, opts_.cse_consider_function); + VLOG(1) << "Finished CSE: graph of " << CountNodes(g_) << " nodes"; + } + } + } +} + +bool GraphConstructor::TypeValidateEdge(const Edge* edge) { + DataType src_out = edge->src()->output_type(edge->src_output()); + DataType dst_in = edge->dst()->input_type(edge->dst_input()); + if (!TypesCompatible(dst_in, src_out)) { + SetError(strings::StrCat( + "Input ", edge->dst_input(), " of node ", edge->dst()->name(), + " was passed ", DataTypeString(src_out), " from ", edge->src()->name(), + ":", edge->src_output(), " incompatible with expected ", + DataTypeString(dst_in), ".")); + return false; + } + return true; +} + +} // namespace + +// ---------------------------------------------------------------------------- +// ConvertGraphDefToGraph +// ---------------------------------------------------------------------------- + +Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + const GraphDef& gdef, Graph* g) { + Status status; + GraphConstructor constructor(opts, &gdef, g, &status); + return status; +} + +// ---------------------------------------------------------------------------- +// CopyGraph +// ---------------------------------------------------------------------------- +void CopyGraph(const Graph& src, Graph* dest) { + for (Node* n : dest->nodes()) { + CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty"; + } + + // Copy the nodes + std::unordered_map + node_map; // "Node in src" -> "Node in *dest" + node_map[src.source_node()] = dest->source_node(); + node_map[src.sink_node()] = dest->sink_node(); + for (Node* n : src.nodes()) { + if (n->IsSource() || n->IsSink()) continue; + CHECK(n->IsOp()); + node_map[n] = dest->CopyNode(n); + } + + // Copy the edges + for (const Edge* e : src.edges()) { + Node* src_copy = node_map[e->src()]; + Node* dst_copy = node_map[e->dst()]; + dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h new file mode 100644 index 0000000000..cd1615ef6b --- /dev/null +++ b/tensorflow/core/graph/graph_constructor.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_ +#define TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Construct a graph *g out of a GraphDef gdef. Returns non-OK on +// error, in which case *g is left in an incomplete state. +struct GraphConstructorOptions { + // If true, allows internal ops in the GraphDef. + bool allow_internal_ops = false; + + // If true, the graph def is expected to have fully specified + // devices for all nodes. A node in the resulting graph "g" has the + // device name set accordingly. + // + // TODO(zhifengc): if possible, consider removing this option. + bool expect_device_spec = false; + + // If true, perform common subexpression elimination on the graph. + // TODO(jeff): Turn this default to true? + bool optimizer_do_cse = false; + + // If "optimizer_do_cse" is true and "cse_consider_function" is + // not nullptr, then only consider nodes for CSE for which + // "cse_consider_function(node)" returns true. 
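+  //
+  // For illustration only (this snippet is not part of the library), a
+  // caller could restrict CSE to a single, hypothetical op type like so:
+  //
+  //   GraphConstructorOptions opts;
+  //   opts.optimizer_do_cse = true;
+  //   opts.cse_consider_function = [](const Node* n) {
+  //     return n->type_string() == "MatMul";
+  //   };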
+ std::function cse_consider_function = nullptr; +}; +extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + const GraphDef& gdef, Graph* g); + +// Make a copy of "src" into "*dest". +// +// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges +// other than the implicit Source/Sink nodes. +extern void CopyGraph(const Graph& src, Graph* dest); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_ diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc new file mode 100644 index 0000000000..61f4427297 --- /dev/null +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -0,0 +1,190 @@ +#include "tensorflow/core/graph/graph_constructor.h" + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/public/status.h" +#include + +// TODO(josh11b): Test InitCostModel(). +// TODO(josh11b): Test setting the "device" field of a NodeDef. +// TODO(josh11b): Test that feeding won't prune targets. + +namespace tensorflow { +namespace { + +class GraphConstructorTest : public ::testing::Test { + protected: + GraphConstructorTest() : g_(new Graph(OpRegistry::Global())) { + RequireDefaultOps(); + } + ~GraphConstructorTest() override {} + + void Convert(const string& gdef_ascii) { + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_)); + } + + void ExpectError(const string& gdef_ascii, const string& expected_error_re) { + Convert(gdef_ascii); + GraphConstructorOptions opts; + Status status = ConvertGraphDefToGraph(opts, gdef_, g_.get()); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(RE2::PartialMatch(status.error_message(), expected_error_re)) + << status; + } + + void ExpectOK(const string& gdef_ascii) { + Convert(gdef_ascii); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get())); + } + + Node* FindNode(const string& name) { + for (Node* n : g_->nodes()) { + if (n->name() == name) return n; + } + return nullptr; + } + + bool HasNode(const string& name) { return FindNode(name) != nullptr; } + + void ExpectNodes(const string& nodes) { + int count = 0; + std::vector actual_nodes; + for (Node* n : g_->nodes()) { + if (n->IsOp()) { + count++; + actual_nodes.push_back(n->name()); + } + } + std::sort(actual_nodes.begin(), actual_nodes.end()); + + LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " "); + + std::vector expected_nodes = str_util::Split(nodes, ','); + std::sort(expected_nodes.begin(), expected_nodes.end()); + for (const string& s : expected_nodes) { + Node* n = FindNode(s); + EXPECT_TRUE(n != nullptr) << s; + } + + EXPECT_TRUE(actual_nodes.size() == expected_nodes.size()) + << "\nActual: " << str_util::Join(actual_nodes, ",") + << "\nExpected: " << str_util::Join(expected_nodes, ","); + } + + bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) { + for (const Edge* e : g_->edges()) { + if (e->src()->name() == src && e->src_output() == src_out && + e->dst()->name() == dst && e->dst_input() == src_out) + return true; + } + return false; + } + bool HasControlEdge(const string& src, const string& dst) { + return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot); + } + + private: + GraphDef gdef_; 
+ std::unique_ptr g_; +}; + +REGISTER_OP("ABC"); +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_OP("TestInt").Input("a: int32"); + +TEST_F(GraphConstructorTest, InvalidNodeName) { + ExpectError("node { name: 'a:b' op: 'ABC' }", + "Node 'a:b': Node name contains invalid characters"); + ExpectError("node { name: '_abc' op: 'ABC' }", + // Can't start with '_' + "Node '_abc': Node name contains invalid characters"); + ExpectOK("node { name: 'a-bc_' op: 'ABC' }"); +} + +TEST_F(GraphConstructorTest, InvalidSourceNodeName) { + ExpectError( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: 'W999' input: 'input' }", + + "Unknown input node.*W999"); +} + +TEST_F(GraphConstructorTest, InvalidSourceNodeIndex) { + ExpectError( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1:1', 'input:1' ] }", + + "Connecting to invalid output 1 of source node W1"); +} + +TEST_F(GraphConstructorTest, GraphWithCycle) { + ExpectError( + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }", + + "cycle"); +} + +TEST_F(GraphConstructorTest, TypeMismatch) { + ExpectError( + "node { name: 'input' op: 'TestInput' }" + "node { name: 'int' op: 'TestInt' input: [ 'input' ] }", + + "Input 0 of node int was passed float from input:0 incompatible with " + "expected int32."); +} + +TEST_F(GraphConstructorTest, EmptyGraph) { ExpectOK(""); } + +TEST_F(GraphConstructorTest, SimpleModel) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"); + EXPECT_TRUE(HasNode("W1")); + EXPECT_TRUE(HasNode("input")); + EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasEdge("W1", 0, "t1", 0)); + EXPECT_TRUE(HasEdge("input", 1, "t1", 0)); +} + +TEST_F(GraphConstructorTest, SimpleModelWithControlEdges) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' input: [ '^W1' ] }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W1', 'input:1', '^t1' ] }"); + EXPECT_TRUE(HasNode("W1")); + EXPECT_TRUE(HasNode("input")); + EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasNode("t2")); + EXPECT_TRUE(HasEdge("W1", 0, "t1", 0)); + EXPECT_TRUE(HasEdge("input", 1, "t1", 0)); + EXPECT_TRUE(HasEdge("W1", 0, "t2", 0)); + EXPECT_TRUE(HasEdge("input", 1, "t2", 0)); + EXPECT_TRUE(HasControlEdge("W1", "input")); + EXPECT_TRUE(HasControlEdge("t1", "t2")); +} + +TEST_F(GraphConstructorTest, Error_ControlEdgeBeforeRealInput) { + ExpectError( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' input: [ '^W1' ] }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W1', '^t1', 'input:1' ] }", + "Node 't2': Control dependencies must come after regular dependencies"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc new file mode 100644 index 0000000000..979604f948 --- /dev/null +++ b/tensorflow/core/graph/graph_def_builder.cc @@ -0,0 +1,121 @@ +#include 
"tensorflow/core/graph/graph_def_builder.h" + +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +GraphDefBuilder::Options::Options(Graph* graph, Status* status) + : graph_(graph), status_(status) {} +GraphDefBuilder::Options::~Options() {} + +GraphDefBuilder::Options GraphDefBuilder::Options::WithName( + StringPiece name) const { + return Options(*this).WithNameImpl(name); +} +GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice( + StringPiece device) const { + return Options(*this).WithDeviceImpl(device); +} +GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput( + Node* control_input) const { + return Options(*this).WithControlInputImpl(control_input); +} +GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs( + gtl::ArraySlice control_inputs) const { + return Options(*this).WithControlInputsImpl(control_inputs); +} +GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl( + StringPiece name) { + name_ = name.ToString(); + return *this; +} +GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl( + StringPiece device) { + device_ = device.ToString(); + return *this; +} +GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl( + Node* control_input) { + control_inputs_.push_back(control_input); + return *this; +} +GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl( + gtl::ArraySlice control_inputs) { + control_inputs_.insert(control_inputs_.end(), control_inputs.begin(), + control_inputs.end()); + return *this; +} + +Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const { + if (status_.ok()) { + graph_.ToGraphDef(graph_def); + } + return status_; +} + +Status GraphDefBuilder::ToGraph(Graph* graph) const { + if (status_.ok()) { + GraphDef graph_def; + graph_.ToGraphDef(&graph_def); + GraphConstructorOptions opts; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph)); + } + return status_; +} + +string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const { + if (name_.empty()) return graph_->NewName(op); + return name_; +} + +Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const { + builder->ControlInputs(control_inputs_); + if (!device_.empty()) builder->Device(device_); + for (const auto& attr : attrs_) { + builder->Attr(attr.first, attr.second); + } + + Node* returned_node; + UpdateStatus(builder->Finalize(graph_, &returned_node)); + return returned_node; +} + +void GraphDefBuilder::Options::UpdateStatus(const Status& status) const { + if (status_ == nullptr) { + TF_CHECK_OK(status); + } else { + status_->Update(status); + } +} + +namespace ops { + +Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, + opts.op_registry()); + return opts.FinalizeBuilder(&node_builder); +} + +Node* UnaryOp(const string& op_name, NodeOut input, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, + opts.op_registry()); + node_builder.Input(input); + return opts.FinalizeBuilder(&node_builder); +} + +Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, + opts.op_registry()); 
+ node_builder.Input(a).Input(b); + return opts.FinalizeBuilder(&node_builder); +} + +} // end namespace ops +} // end namespace tensorflow diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h new file mode 100644 index 0000000000..bb72f9eea6 --- /dev/null +++ b/tensorflow/core/graph/graph_def_builder.h @@ -0,0 +1,181 @@ +#ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_ +#define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_ + +#include +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Given a function like: +// namespace ops { +// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { +// if (opts.HaveError()) return nullptr; +// static const string kOpName = "Identity"; +// NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName, +// opts.op_registry()); +// node_builder.Input(input); +// return opts.FinalizeBuilder(&node_builder); +// } +// } // namspace ops +// +// // Or, alternatively: +// namespace ops { +// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { +// static const string kOpName = "Identity"; +// return UnaryOp(kOpName, input, opts); +// } +// } // namspace ops +// +// You call it like: +// GraphDefBuilder b; +// using namespace ::tensorflow::ops; // NOLINT(build/namespaces) +// Node* a = Const(7, b.opts()); +// // Note: WithName() returns a copy, opts is unchanged. +// Node* b = Const(5, b.opts().WithName("control-input")); +// Node* c = Identity(a, b.opts().WithControlInput(b)); +// GraphDef graph_def; +// Status status = b.ToGraphDef(&graph_def); +// if (!status.ok()) { /* Handle error */ } +// +// In tests you can skip the status handling via: +// GraphDefBuilder b(GraphDefBuilder::kFailImmediately); +// ... +// b.ToGraphDef(&graph_def); + +class GraphDefBuilder { + public: + // Options for adding a Node to a Graph. + class Options { + public: + // Sets the Graph (that Nodes will be added to) and the status. The + // status may be set to nullptr, in which case errors cause CHECK + // failures. The graph and status must outlive *this. + Options(Graph* graph, Status* status); + ~Options(); + + // Methods for setting options. These are const methods: they + // return a copy of *this with the option set. + Options WithName(StringPiece name) const; + Options WithDevice(StringPiece device) const; + Options WithControlInput(Node* control_input) const; + Options WithControlInputs(gtl::ArraySlice control_inputs) const; + + // Override the default value for an optional attr. + template + Options WithAttr(StringPiece attr_name, T&& value) const { + return Options(*this).WithAttrImpl(attr_name, std::forward(value)); + } + // Note: overload needed to allow {...} expressions for value. + template + Options WithAttr(StringPiece attr_name, + std::initializer_list value) const { + return WithAttr>(attr_name, std::move(value)); + } + + // Methods for using options from a function that creates a Node. + + // Returns true if the status associated with *this has an error. + // Use this to skip processing that may depend on prior results. + bool HaveError() const { return status_ != nullptr && !status_->ok(); } + + // Given the Op type name, return a name for a node of that type. 
+ // Uses the value set in WithName() if that has been called. Otherwise, + // returns a name built out of the Op type name. + string GetNameForOp(StringPiece op) const; + + // Sets the device, adds control inputs, adds attrs, and calls Finalize(). + // If Finalize returns an error, it is saved and this function returns + // nullptr. + Node* FinalizeBuilder(NodeBuilder* builder) const; + + // Updates the associated status, if any, or calls TF_CHECK_OK if none. + void UpdateStatus(const Status& status) const; + + // Accessor + const OpRegistryInterface* op_registry() const { + return graph_->op_registry(); + } + + private: + Options WithNameImpl(StringPiece name); + Options WithDeviceImpl(StringPiece device); + Options WithControlInputImpl(Node* control_input); + Options WithControlInputsImpl(gtl::ArraySlice control_inputs); + template + Options WithAttrImpl(StringPiece name, T&& value) { + attrs_.emplace_back(name.ToString(), AttrValue()); + SetAttrValue(std::forward(value), &attrs_.back().second); + return *this; + } + + Graph* const graph_; + Status* const status_; + string name_; + string device_; + std::vector control_inputs_; + std::vector> attrs_; + }; + + // Start building a new graph. + explicit GraphDefBuilder( + const OpRegistryInterface* op_registry = OpRegistry::Global()) + : graph_(op_registry), opts_(&graph_, &status_) {} + + // For use in tests, where you want to fail immediately on error instead + // of checking the status at the end. + enum TestFailImmediatelyType { kFailImmediately }; + explicit GraphDefBuilder( + TestFailImmediatelyType, + const OpRegistryInterface* op_registry = OpRegistry::Global()) + : graph_(op_registry), opts_(&graph_, nullptr) {} + + // Gets the Options with the associated Graph and Status. + const Options& opts() const { return opts_; } + + // Once all the nodes have been added, call this to get whether it was + // successful, and if so fill *graph_def. + Status ToGraphDef(GraphDef* graph_def) const; + + // Like ToGraphDef(), but converts to a Graph (using the default + // GraphConstructorOptions). + // TODO(josh11b): Make this faster; right now it converts + // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds + // edges from the source and to the sink node, resolves back edges + // by name), and makes sure the resulting graph is valid. + Status ToGraph(Graph* graph) const; + + private: + Graph graph_; + Status status_; + Options opts_; +}; + +namespace ops { + +// A NodeOut may either be a regular input or back input. Regular +// inputs are specified via either a Node* or a Node* and an output +// index. Back inputs are specified by a node name, output index, and +// output type. +typedef NodeBuilder::NodeOut NodeOut; + +// For adding an Op with no inputs to a GraphDefBuilder. +Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts); + +// For adding an Op with one input to a GraphDefBuilder. +Node* UnaryOp(const string& op_name, NodeOut input, + const GraphDefBuilder::Options& opts); + +// For adding an Op with two inputs to a GraphDefBuilder. 
+Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, + const GraphDefBuilder::Options& opts); + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_ diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc new file mode 100644 index 0000000000..1571790e59 --- /dev/null +++ b/tensorflow/core/graph/graph_partition.cc @@ -0,0 +1,1050 @@ +#include "tensorflow/core/graph/graph_partition.h" + +#include +#include + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +struct DupRecvKey { + int src_node_id; // Edge's src node id + int src_output_slot; // Edge's src node output slot + GraphDef* dst_graph; // Edge's dst node is in this subgraph + bool recv_output_on_host; // The output of recv is on host +}; + +struct DupRecvKeyHash { + size_t operator()(const DupRecvKey& k) const { + size_t h = Hash64(reinterpret_cast(&k.src_node_id), + sizeof(k.src_node_id), k.src_output_slot); + h = Hash64(reinterpret_cast(&k.dst_graph), sizeof(k.dst_graph), + h); + h = Hash64(reinterpret_cast(&k.recv_output_on_host), + sizeof(k.recv_output_on_host), h); + return h; + } +}; + +struct DupRecvKeyEq { + bool operator()(const DupRecvKey& x, const DupRecvKey& y) const { + return (x.src_node_id == y.src_node_id) && + (x.src_output_slot == y.src_output_slot) && + (x.dst_graph == y.dst_graph) && + (x.recv_output_on_host == y.recv_output_on_host); + } +}; + +// struct used to store the recvs, so that start times can be properly updated +struct RecvInfo { + NodeDef* recv; + NodeDef* real_recv; + int64 start_time; +}; + +typedef std::unordered_map + DupRecvTable; + +// Control flow info for a graph node. +struct ControlFlowInfo { + const Node* frame = nullptr; // frame of a node + const Node* parent_frame = nullptr; // parent frame of a node + string frame_name; // frame name of a node + int iter_level = -1; // level of a node +}; + +struct PairIntHash { + public: + std::size_t operator()(const std::pair& x) const { + return std::hash()(x.first) ^ std::hash()(x.second); + } +}; +// A map used to store memory types for the inputs/outputs of every node. +// The key is a pair of ints consisting of a node id and input/output index. +typedef std::unordered_map, MemoryType, PairIntHash> + MemoryTypeMap; + +// We collect the following information about the graph before performing +// graph partitioning. +struct GraphInfo { + std::vector device_types; + MemoryTypeMap input_types; + MemoryTypeMap output_types; + std::vector cf_info; +}; + +DataType EdgeType(const Edge* e) { + if (e->IsControlEdge()) { + return DT_FLOAT; + } else { + return e->dst()->input_type(e->dst_input()); + } +} + +// Return true iff we need to add a same device send/recv for 'edge'. 
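+// That is the case only for a data edge whose endpoints are assigned to the
+// same GPU device but disagree on memory type (one side is HOST_MEMORY, the
+// other DEVICE_MEMORY), so the tensor still needs an explicit copy.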
+bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) { + if (edge->IsControlEdge()) { + return false; + } + + Node* src = edge->src(); + Node* dst = edge->dst(); + if (src->assigned_device_name() == dst->assigned_device_name()) { + int src_port = edge->src_output(); + int dst_port = edge->dst_input(); + if (info.device_types[src->id()] == DEVICE_GPU) { + auto src_it = info.output_types.find({src->id(), src_port}); + DCHECK(src_it != info.output_types.end()); + auto dst_it = info.input_types.find({dst->id(), dst_port}); + DCHECK(dst_it != info.input_types.end()); + return src_it->second != dst_it->second; + } + } + return false; +} + +// Return true iff (dst, dst_input) is specified on host memory. +bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) { + Node* dst = edge->dst(); + int dst_port = edge->dst_input(); + if (info.device_types[dst->id()] == DEVICE_GPU) { + if (edge->IsControlEdge()) return false; + auto dst_it = info.input_types.find({dst->id(), dst_port}); + DCHECK(dst_it != info.input_types.end()); + return dst_it->second == HOST_MEMORY; + } + return true; +} + +// Add an input to dst that comes from the "src_slot" output of the +// node named by "src_name". +void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { + if (src_slot == Graph::kControlSlot) { + dst->add_input(strings::StrCat("^", src_name)); + } else if (src_slot == 0) { + dst->add_input(src_name.data(), src_name.size()); + } else { + dst->add_input(strings::StrCat(src_name, ":", src_slot)); + } +} + +// Add a control edge from each input to each recv. +void AddReadControl(const std::vector& recvs, + const std::vector& inputs) { + for (NodeDef* recv : recvs) { + for (const string& input : inputs) { + recv->add_input(strings::StrCat("^", input)); + } + } +} + +void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge, + NodeDefBuilder* builder) { + builder->Attr("tensor_name", + strings::StrCat("edge_", edge->id(), "_", edge->src()->name())); + builder->Attr("send_device", edge->src()->assigned_device_name()); + builder->Attr("send_device_incarnation", + static_cast( + opts.get_incarnation(edge->src()->assigned_device_name()))); + builder->Attr("recv_device", edge->dst()->assigned_device_name()); + builder->Attr("client_terminated", false); +} + +NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, + GraphDef* gdef, const Edge* edge, + NodeDefBuilder::NodeOut send_from, int64 start_time, + Status* status) { + const DataType dtype = send_from.data_type; + const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype; + const Node* src = edge->src(); + const int src_port = edge->src_output(); + + // host_memory = true iff we need to use HostSend/HostCast. + bool host_memory = false; + if (!edge->IsControlEdge()) { + auto src_it = g_info.output_types.find({src->id(), src_port}); + DCHECK(src_it != g_info.output_types.end()); + host_memory = (src_it->second == HOST_MEMORY); + } + + // Add a cast node that casts dtype to cast_dtype. + // NOTE(yuanbyu): Only cast for cross-device send/recv. + if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) { + const string cast_op = (host_memory) ? 
"_HostCast" : "Cast"; + NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op); + cast_builder.Device(src->assigned_device_name()).Input(send_from); + if (opts.scheduling_for_recvs) { + cast_builder.Attr("_start_time", start_time); + } + cast_builder.Attr("DstT", cast_dtype); + NodeDef* cast = gdef->add_node(); + *status = cast_builder.Finalize(cast); + if (!status->ok()) return nullptr; + + // Connect the Send op to the cast. + send_from.Reset(cast->name(), 0, cast_dtype); + } + + // Add the send node. + const string send_op = (host_memory) ? "_HostSend" : "_Send"; + NodeDefBuilder send_builder(opts.new_name(src->name()), send_op); + SetSendRecvAttrs(opts, edge, &send_builder); + send_builder.Device(src->assigned_device_name()).Input(send_from); + if (opts.scheduling_for_recvs) { + send_builder.Attr("_start_time", start_time); + } + NodeDef* send = gdef->add_node(); + *status = send_builder.Finalize(send); + return send; +} + +NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info, + GraphDef* gdef, const Edge* edge, NodeDef** real_recv, + Status* status) { + const DataType dtype = EdgeType(edge); + const Node* src = edge->src(); + const Node* dst = edge->dst(); + const int dst_port = edge->dst_input(); + DataType cast_dtype = dtype; + + // NOTE(yuanbyu): Only cast for cross-device send/recv. + if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) { + cast_dtype = opts.should_cast(edge); + } + + // host_memory = true iff we need to use HostRecv/HostCast. + bool host_memory = false; + if (!edge->IsControlEdge()) { + auto dst_it = g_info.input_types.find({dst->id(), dst_port}); + DCHECK(dst_it != g_info.input_types.end()); + host_memory = (dst_it->second == HOST_MEMORY); + } + + // Add the recv node. + const string recv_op = (host_memory) ? "_HostRecv" : "_Recv"; + NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op); + SetSendRecvAttrs(opts, edge, &recv_builder); + recv_builder.Device(dst->assigned_device_name()) + .Attr("tensor_type", cast_dtype); + NodeDef* recv = gdef->add_node(); + *status = recv_builder.Finalize(recv); + if (!status->ok()) return nullptr; + *real_recv = recv; + + // Add the cast node (from cast_dtype to dtype) or an Identity node. + if (dtype != cast_dtype) { + const string cast_op = (host_memory) ? "_HostCast" : "Cast"; + NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op); + cast_builder.Attr("DstT", dtype); + cast_builder.Device(dst->assigned_device_name()) + .Input(recv->name(), 0, cast_dtype); + NodeDef* cast = gdef->add_node(); + *status = cast_builder.Finalize(cast); + if (!status->ok()) return nullptr; + return cast; + } else if (edge->IsControlEdge()) { + // An Identity is only needed for control edges. + NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity"); + id_builder.Device(dst->assigned_device_name()) + .Input(recv->name(), 0, cast_dtype); + NodeDef* id = gdef->add_node(); + *status = id_builder.Finalize(id); + if (!status->ok()) return nullptr; + return id; + } else { + return recv; + } +} + +NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef, + const Edge* edge, Status* status) { + const Node* src = edge->src(); + Tensor tensor(DT_FLOAT, TensorShape({0})); + NodeDef* result = gdef->add_node(); + *status = NodeDefBuilder(opts.new_name(src->name()), "Const") + .Device(src->assigned_device_name()) + .Attr("dtype", DT_FLOAT) + .Attr("value", tensor) + .Finalize(result); + return result; +} + +// A dummy node for scheduling. 
+NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef, + const string& assigned_device_name, int64 epoch, + int64 starttime, Status* status) { + NodeDef* result = gdef->add_node(); + *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)), + "ControlTrigger") + .Device(assigned_device_name) + .Attr("_start_time", starttime) + .Finalize(result); + return result; +} + +// Assign to each node the name of the frame and the level it belongs to. +// We check the well-formedness of the graph: All inputs to a node must +// come from the same frame and have the same "static" iteration level. +// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level +// 0. This essentially means there can't be multiple serial Nexts in +// an iteration, which all sane front-ends should satisfy. +Status BuildControlFlowInfo(Graph* g, std::vector* info) { + info->clear(); + info->resize(g->num_node_ids()); + + Node* src_node = g->source_node(); + ControlFlowInfo& src_info = (*info)[src_node->id()]; + src_info.frame = src_node; + src_info.parent_frame = src_node; + src_info.iter_level = 0; + + string frame_name; + std::deque ready; + ready.push_back(src_node); + while (!ready.empty()) { + const Node* curr_node = ready.front(); + ready.pop_front(); + const ControlFlowInfo& curr_info = (*info)[curr_node->id()]; + const Node* frame = curr_info.frame; + const Node* parent = curr_info.parent_frame; + frame_name = curr_info.frame_name; + int iter_level = curr_info.iter_level; + + if (IsExit(curr_node)) { + const ControlFlowInfo& parent_info = (*info)[parent->id()]; + frame = parent_info.frame; + parent = parent_info.parent_frame; + frame_name = parent_info.frame_name; + iter_level = parent_info.iter_level; + } + + for (const Edge* out_edge : curr_node->out_edges()) { + const Node* out = out_edge->dst(); + int out_id = out->id(); + ControlFlowInfo* out_info = &(*info)[out_id]; + const Node* out_parent = out_info->parent_frame; + bool is_visited = (out_info->iter_level != -1); + + // Skip Sink/Source nodes. + if (!out->IsOp()) continue; + + // Add to ready queue if not seen. + if (!is_visited) { + ready.push_back(out); + } + + // Process the node 'out'. 
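+      // Three cases follow: an Enter starts a new child frame at iteration
+      // level 0, a NextIteration stays in the current frame but bumps the
+      // level by one, and any other node simply inherits the frame and level
+      // of its input. If 'out' was already visited, the recorded info must
+      // agree with what we would assign here, otherwise the graph is
+      // rejected as malformed.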
+ if (IsEnter(out)) { + if (is_visited) { + const string& parent_name = (*info)[out_parent->id()].frame_name; + if (parent_name != frame_name || iter_level != out_info->iter_level) { + return errors::InvalidArgument( + "All inputs to Enter must be from the same frame and level."); + } + } else { + out_info->frame = out; + out_info->parent_frame = frame; + TF_RETURN_IF_ERROR( + GetNodeAttr(out->def(), "frame_name", &out_info->frame_name)); + if (out_info->frame_name.empty()) { + return errors::InvalidArgument( + "Enter must have a non-empty frame name."); + } + out_info->iter_level = 0; + } + } else if (IsNextIteration(out)) { + if (is_visited) { + if (out_info->frame_name != frame_name || + out_info->iter_level != (iter_level + 1)) { + return errors::InvalidArgument( + "All inputs to NextIteration must be from the same frame " + "and level."); + } + } else { + out_info->frame = frame; + out_info->parent_frame = parent; + out_info->frame_name = frame_name; + out_info->iter_level = iter_level + 1; + } + } else { + if (is_visited) { + if (out_info->frame_name != frame_name) { + return errors::InvalidArgument( + "All inputs to a node must be from the same frame."); + } + } else { + out_info->frame = frame; + out_info->parent_frame = parent; + out_info->frame_name = frame_name; + out_info->iter_level = iter_level; + } + } + } + } + + return Status::OK(); +} + +string ControlLoopName(const string& name) { + return strings::StrCat("_cloop", name); +} + +bool IsControlLoop(const Node* node) { + const string& name = node->def().name(); + return StringPiece(name).starts_with("_cloop"); +} + +// An enter node for control flow. +Node* AddControlEnter(Graph* g, const string& node_name, + const string& device_name, const string& frame_name, + const int parallel_iterations, Status* status) { + NodeBuilder node_builder(node_name, "Enter", g->op_registry()); + node_builder.Input({"dummy", 0, DT_FLOAT}); + node_builder.Attr("frame_name", frame_name); + node_builder.Attr("parallel_iterations", parallel_iterations); + Node* res_node; + *status = node_builder.Finalize(g, &res_node); + if (!status->ok()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +// A merge node for control flow. +Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g, + const string& node_name, const string& device_name, + Status* status) { + NodeBuilder node_builder(node_name, "Merge", g->op_registry()); + node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}}); + Node* res_node; + *status = node_builder.Finalize(g, &res_node); + if (!status->ok()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +// A switch node for control flow. +Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2, + const string& device_name, + const GraphDefBuilder::Options& bopts) { + Node* res_node = ops::BinaryOp("Switch", input1, input2, bopts); + if (bopts.HaveError()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +// A next_iteration node for control flow. 
+Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name, + const GraphDefBuilder::Options& bopts) { + Node* res_node = ops::UnaryOp("NextIteration", input, bopts); + if (bopts.HaveError()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +Node* EmptyConst(const GraphDefBuilder::Options& options) { + if (options.HaveError()) return nullptr; + NodeBuilder node_builder(options.GetNameForOp("Const"), "Const", + options.op_registry()); + const DataType dt = DataTypeToEnum::v(); + TensorProto proto; + proto.set_dtype(dt); + TensorShape empty_shape({0}); + empty_shape.AsProto(proto.mutable_tensor_shape()); + node_builder.Attr("dtype", dt).Attr("value", proto); + return options.FinalizeBuilder(&node_builder); +} + +// A dummy const node for control flow. +Node* AddControlConst(const string& device_name, + const GraphDefBuilder::Options& bopts) { + Node* res_node = EmptyConst(bopts); + if (bopts.HaveError()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +// A synthetic loop, made up of dummy nodes. It performs control-flow actions +// on behalf of a leader on a different device. +struct ControlLoop { + Node* enter = nullptr; + Node* merge = nullptr; + Node* switch_node = nullptr; +}; + +// Add the control flow info of a new node added during partitioning. +// The new node has the same control flow info as edge->src(). +void AddControlFlowInfo(const Node* node, const Node* src, + std::vector* cf_info) { + int id = node->id(); + if (static_cast(id) >= cf_info->size()) { + cf_info->resize(id + 1); + } + const ControlFlowInfo& src_info = (*cf_info)[src->id()]; + ControlFlowInfo* info = &(*cf_info)[id]; + info->frame = src_info.frame; + info->parent_frame = src_info.parent_frame; + info->frame_name = src_info.frame_name; + info->iter_level = src_info.iter_level; +} + +// Constructs a control loop. Returns a struct containing the newly created +// enter, merge, and switch nodes. The enter and merge nodes are used in the +// recursive construction of control loops for nested frames (loops). The +// switch node will be connected to the LoopCond node. The merge node will +// be connected to all the recvs of the same frame by control edges when +// the actual partitioning happens. +Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, + const Edge* edge, Node* loop_cond, + std::vector* cf_info, + ControlLoop* loop) { + Status status; + GraphDefBuilder::Options bopts(g, &status); + const ControlFlowInfo& src_info = (*cf_info)[src->id()]; + const string& device_name = edge->dst()->assigned_device_name(); + const string& frame_name = src_info.frame_name; + int parallel_iterations; + status = GetNodeAttr(src_info.frame->def(), "parallel_iterations", + ¶llel_iterations); + if (!status.ok()) return status; + + // The names of the nodes to be added. + const string& enter_name = + ControlLoopName(opts.new_name(edge->dst()->name())); + const string& merge_name = + ControlLoopName(opts.new_name(edge->dst()->name())); + const string& switch_name = + ControlLoopName(opts.new_name(edge->dst()->name())); + const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name())); + + // Add the nodes to the graph g. 
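+  // The loop skeleton built here is: enter -> merge:0, next -> merge:1,
+  // merge -> switch:0, loop_cond -> switch:1, and switch:1 -> next, i.e. a
+  // minimal while-loop whose merge output can later gate this frame's recvs.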
+ Node* enter = AddControlEnter(g, enter_name, device_name, frame_name, + parallel_iterations, &status); + if (!status.ok()) return status; + Node* merge = AddControlMerge(enter_name, next_name, g, merge_name, + device_name, &status); + if (!status.ok()) return status; + Node* switch_node = AddControlSwitch(merge, loop_cond, device_name, + bopts.WithName(switch_name)); + if (!status.ok()) return status; + Node* next = + AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name)); + if (!status.ok()) return status; + + // Add control flow info for these new nodes: + AddControlFlowInfo(enter, src, cf_info); + AddControlFlowInfo(merge, src, cf_info); + AddControlFlowInfo(switch_node, src, cf_info); + AddControlFlowInfo(next, src, cf_info); + + // Add input edges for the newly created merge node: + g->AddEdge(enter, 0, merge, 0); + g->AddEdge(next, 0, merge, 1); + + loop->enter = enter; + loop->merge = merge; + loop->switch_node = switch_node; + return Status::OK(); +} + +// Build memory and device type info for every node in the graph. +// TODO(yuanbyu): It might be simpler if we convert MemoryType to +// DeviceType for the inputs/outputs of each node. +Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { + Status status; + MemoryTypeVector input_memory_types; + MemoryTypeVector output_memory_types; + + info->device_types.resize(g.num_node_ids(), DEVICE_CPU); + for (const Node* node : g.nodes()) { + if (!node->IsOp()) continue; // Skip Sink/Source nodes. + + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(), + &parsed)) { + return errors::Internal("Malformed assigned device '", + node->assigned_device_name(), "'"); + } + + input_memory_types.clear(); + input_memory_types.resize(node->num_inputs()); + output_memory_types.clear(); + output_memory_types.resize(node->num_outputs()); + status = MemoryTypesForNode(g.op_registry(), DeviceType(parsed.type), + node->def(), &input_memory_types, + &output_memory_types); + if (!status.ok()) return status; + + int node_id = node->id(); + info->device_types[node_id] = DeviceType(parsed.type); + for (size_t i = 0; i < input_memory_types.size(); ++i) { + info->input_types[{node_id, i}] = input_memory_types[i]; + } + for (size_t i = 0; i < output_memory_types.size(); ++i) { + info->output_types[{node_id, i}] = output_memory_types[i]; + } + } + return status; +} + +// Each participating device needs to decide a) if there is a next iteration, +// and b) if the loop terminates. We take the approach to encode this control +// flow logic in the dataflow graph. There are at least two possible encodings. +// In a completely decentralized encoding, the participants communicate peer +// to peer. The other encoding uses a frame leader (the participant who owns +// the pivot termination predicate) to broadcast the termination condition to +// all the participants. For now we take the latter because it is simpler. +// +// TODO(yuanbyu): The correctness of this construction is rather subtle. I got +// it wrong many times so it would be nice to write a proof to be sure. +Status AddControlFlow(const PartitionOptions& opts, Graph* g, + GraphInfo* g_info) { + Status status; + GraphDefBuilder::Options bopts(g, &status); + std::vector& cf_info = g_info->cf_info; + + // Build the control flow info for every node. + status = BuildControlFlowInfo(g, &cf_info); + if (!status.ok()) return status; + + // The map from frames to their LoopCond nodes. 
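+  // Keyed by frame name; consulted below when a cross-device edge inside a
+  // frame requires a control loop, since the loop's Switch must be driven by
+  // that frame's pivot predicate.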
+ std::unordered_map frame_cond_map; + int num_node_ids = g->num_node_ids(); + for (int i = 0; i < num_node_ids; ++i) { + Node* node = g->FindNodeId(i); + if (node == nullptr) continue; + + if (IsLoopCond(node)) { + const string& frame_name = cf_info[node->id()].frame_name; + DCHECK(!frame_name.empty()); + frame_cond_map[frame_name] = node; + } + } + + // Add all control loops for cross-device frames. + // A control loop is added only when there is a cross-device edge in a + // non-root frame. Nothing is added if there is no loops. We also don't + // add anything for a frame that is completely local to a device. For + // nested loops, we stack the control loops together by connecting + // the merge of the outer loop to the enter of the inner loop. + // + // A map from to ControlLoop. + std::unordered_map control_loops; + int num_edge_ids = g->num_edge_ids(); + for (int i = 0; i < num_edge_ids; ++i) { + const Edge* edge = g->FindEdgeId(i); + if (edge == nullptr) continue; + + const Node* src = edge->src(); + const Node* dst = edge->dst(); + // Skip Sink/Source nodes. + if (!src->IsOp() || !dst->IsOp()) continue; + + const string& src_device = src->assigned_device_name(); + const string& dst_device = dst->assigned_device_name(); + // Skip local edges. + if (src_device == dst_device) continue; + + const string& src_frame = cf_info[src->id()].frame_name; + const string& dst_frame = cf_info[dst->id()].frame_name; + // Skip if src and dst are not in the same frame. + if (src_frame.empty() || src_frame != dst_frame) { + continue; + } + + // Add the control loop. Start by adding the control loop for the + // current frame if needed, and recursively adding the control loop + // for its outer frame when nested. + ControlLoop child_loop; + while (true) { + const string& curr_frame = cf_info[src->id()].frame_name; + if (curr_frame.empty()) { + // We have reached the root frame. + if (child_loop.merge != nullptr) { + const string& node_name = opts.new_name(edge->dst()->name()); + const string& device_name = edge->dst()->assigned_device_name(); + Node* const_node = + AddControlConst(device_name, bopts.WithName(node_name)); + if (!status.ok()) return status; + AddControlFlowInfo(const_node, src, &cf_info); + g->AddEdge(const_node, 0, child_loop.enter, 0); + } + break; + } + + const string& cl_key = strings::StrCat(curr_frame, "$$", dst_device); + auto it = control_loops.find(cl_key); + if (it != control_loops.end()) { + if (child_loop.enter != nullptr) { + g->AddEdge(it->second.merge, 0, child_loop.enter, 0); + } + break; + } + + // Get the frame's LoopCond. + auto cond_it = frame_cond_map.find(curr_frame); + if (cond_it == frame_cond_map.end()) { + return errors::InvalidArgument( + "A cross-device loop must have a pivot predicate: ", curr_frame); + } + Node* loop_cond = cond_it->second; + + // Add the control loop. + ControlLoop curr_loop; + status = + AddControlLoop(opts, g, src, edge, loop_cond, &cf_info, &curr_loop); + if (!status.ok()) return status; + control_loops[cl_key] = curr_loop; + + if (child_loop.enter != nullptr) { + // Connect the merge of the outer loop to the enter of the inner. + g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0); + } + src = cf_info[src->id()].parent_frame; + child_loop = curr_loop; + } + } + + // For a cross-device edge, on the dst device, add a control edge + // from the merge node of the control loop to dst. 
If a send/recv is + // introduced for this edge in future partitioning, we delete this + // control edge and add a new control edge from the merge to the recv. + num_edge_ids = g->num_edge_ids(); + for (int i = 0; i < num_edge_ids; ++i) { + const Edge* edge = g->FindEdgeId(i); + if (edge == nullptr) continue; + + const Node* src = edge->src(); + Node* dst = edge->dst(); + // Skip Sink/Source nodes. + if (!src->IsOp() || !dst->IsOp()) continue; + + const string& src_device = src->assigned_device_name(); + const string& dst_device = dst->assigned_device_name(); + if (src_device != dst_device) { + const string& src_frame = cf_info[src->id()].frame_name; + const string& dst_frame = cf_info[dst->id()].frame_name; + if (!src_frame.empty() && src_frame == dst_frame) { + const string& cl_key = strings::StrCat(dst_frame, "$$", dst_device); + ControlLoop loop = control_loops[cl_key]; + DCHECK(loop.enter != nullptr); + g->AddControlEdge(loop.merge, dst); + } + } + } + return Status::OK(); +} + +} // end namespace + +Status AddControlEdges(const PartitionOptions& opts, + std::unordered_map* partitions) { + Status status; + // TODO(yuanbyu): Very naive for now. To be improved. + const int num_epochs = 100; + const int prefetch = 6; + + typedef std::pair NodeStartTime; + for (auto& part : *partitions) { + GraphDef* gdef = &part.second; + + std::vector start_times; + start_times.resize(gdef->node_size()); + for (int n = 0; n < gdef->node_size(); ++n) { + const NodeDef& ndef = gdef->node(n); + int64 start_time; + status = GetNodeAttr(ndef, "_start_time", &start_time); + if (!status.ok()) { + return status; + } + start_times[n] = std::make_pair(&ndef, start_time); + } + + // Sort the nodes based on their start times. + std::sort( + start_times.begin(), start_times.end(), + [](NodeStartTime x, NodeStartTime y) { return x.second < y.second; }); + + // Add a dummy node for every epoch, and add a control edge from the + // "last" node in the preceding epoch to the dummy node. + string device_name = gdef->node(0).device(); + int64 makespan = start_times.back().second; + int64 resolution = (makespan / num_epochs) + 1; + + int i = 0; + int j = 0; + std::vector dummys; + while (i < num_epochs && static_cast(j) < start_times.size()) { + if (i * resolution > start_times[j].second) { + j++; + } else { + NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i, + i * resolution, &status); + if (!status.ok()) { + return status; + } + dummys.push_back(dummy); + if (j > 0) { + string src_name = start_times[j - 1].first->name(); + AddInput(dummy, src_name, Graph::kControlSlot); + } + i++; + } + } + + // Finally, add the control edges to recvs. + for (int n = 0; n < gdef->node_size(); ++n) { + NodeDef* ndef = gdef->mutable_node(n); + if (ndef->op() == "_Recv") { + int64 start_time; + status = GetNodeAttr(*ndef, "_start_time", &start_time); + if (!status.ok()) { + return status; + } + int recv_epoch = start_time / resolution; + if (recv_epoch >= prefetch) { + NodeDef* dummy = dummys[recv_epoch - prefetch]; + AddInput(ndef, dummy->name(), Graph::kControlSlot); + } + } + } + } + return Status::OK(); +} + +Status Partition(const PartitionOptions& opts, Graph* g, + std::unordered_map* partitions) { + Status status; + partitions->clear(); + + GraphInfo g_info; + if (!opts.control_flow_added) { + // Add the "code" for distributed execution of control flow. Code is + // added only for the frames that are placed on multiple devices. 
The + // new graph is an equivalent transformation of the original graph and + // has the property that it can be subsequently partitioned arbitrarily + // (down to the level of individual device) for distributed execution. + status = AddControlFlow(opts, g, &g_info); + if (!status.ok()) return status; + } + // At this point, all the graph mutations have been done. Build memory + // and device type info for every node and edge in the graph. + status = BuildMemoryDeviceInfo(*g, &g_info); + if (!status.ok()) return status; + + string dstp; + std::vector inputs; + DupRecvTable dup_recv(3); + // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref + // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref + // edge to dst. We will add a control edge for every pair in + // (ref_recvs x ref_control_inputs). + std::vector ref_recvs; + std::vector ref_control_inputs; + + int32 num_data = 0; + int32 num_control = 0; + for (const Node* dst : g->nodes()) { + if (!dst->IsOp()) continue; // Skip Sink/Source nodes. + + dstp = opts.node_to_loc(dst); + GraphDef* dst_graph = &(*partitions)[dstp]; + NodeDef* dst_def = dst_graph->add_node(); + *dst_def = dst->def(); + dst_def->set_device(dst->assigned_device_name()); + dst_def->clear_input(); // Inputs are filled below + if (opts.need_to_record_start_times) { + int64 start_time = opts.start_times[dst->id()].value(); + AddNodeAttr("_start_time", start_time, dst_def); + } + + // Arrange the incoming edges to dst so that input[i] holds the + // input flowing into slot numbered i. Trailing entries in input[] + // hold control edges. + inputs.clear(); + inputs.resize(dst->num_inputs(), nullptr); + ref_recvs.clear(); + ref_control_inputs.clear(); + const Edge* control_flow_edge = nullptr; + for (const Edge* edge : dst->in_edges()) { + if (edge->IsControlEdge()) { + if (IsMerge(edge->src()) && IsControlLoop(edge->src())) { + // This is one of the control edges added for control flow. There + // can be multiple such edges as the dest node may have multiple + // remote inputs. We will just take one and ignore the others. + control_flow_edge = edge; + } else { + inputs.push_back(edge); + } + } else { + DCHECK(inputs[edge->dst_input()] == nullptr); + inputs[edge->dst_input()] = edge; + } + } + + // Process in order so that all data edges are added as inputs to + // dst in Edge::dst_input() order. + bool recv_added = false; + for (const Edge* edge : inputs) { + const Node* src = edge->src(); + if (!src->IsOp()) continue; // Skip Sink/Source nodes. + + GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)]; + if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) { + // Same partition and compatible memory types: + AddInput(dst_def, src->name(), edge->src_output()); + if (edge->IsControlEdge() || + !IsRefType(src->output_type(edge->src_output()))) { + ref_control_inputs.push_back(src->name()); + } + continue; + } + + int64 send_start_time = 0; + int64 recv_start_time = 0; + if (opts.scheduling_for_recvs) { + if (opts.need_to_record_start_times) { + send_start_time = opts.start_times[src->id()].value(); + recv_start_time = opts.start_times[dst->id()].value(); + } else { + status = GetNodeAttr(src->def(), "_start_time", &send_start_time); + if (!status.ok()) { + return status; + } + status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time); + if (!status.ok()) { + return status; + } + } + } + + // Check whether there is already a send/recv pair transferring + // the same tensor/control from the src to dst partition. 
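+      // Duplicates are keyed by DupRecvKey (src node id, src output slot,
+      // destination GraphDef, and whether the recv lives in host memory);
+      // reusing an existing recv avoids shipping the same tensor or control
+      // signal across the partition boundary more than once.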
+ const bool on_host = IsDstInputOnHost(edge, g_info); + DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host}; + auto iter = dup_recv.find(key); + if (iter != dup_recv.end()) { + // We found one. Reuse the data/control transferred already. + const string& recv_node_name = iter->second.recv->name(); + if (edge->IsControlEdge()) { + AddInput(dst_def, recv_node_name, Graph::kControlSlot); + } else { + AddInput(dst_def, recv_node_name, 0); + } + // We want the start_time for the recv to be the smallest of the start + // times of it's consumers. So we update this whenever we use a recv, + // and write it out to the attribute at the end of the subroutine + if (iter->second.start_time > recv_start_time) { + iter->second.start_time = recv_start_time; + } + continue; + } + + NodeDefBuilder::NodeOut send_from; + if (edge->IsControlEdge()) { + // Insert a dummy const node that will generate a tiny + // data element to be sent from send to recv. + VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "[" + << src->name() << "] -> " << dst->assigned_device_name() << "[" + << dst->name() << "]"; + NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status); + if (!status.ok()) return status; + // Set the start time for this dummy node. + if (opts.scheduling_for_recvs) { + AddNodeAttr("_start_time", send_start_time, dummy); + } + AddInput(dummy, src->name(), Graph::kControlSlot); + send_from.Reset(dummy->name(), 0, DT_FLOAT); + } else { + send_from.Reset(src->name(), edge->src_output(), EdgeType(edge)); + } + + // Need to split edge by placing matching send/recv nodes on + // the src/dst sides of the edge. + NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from, + send_start_time, &status); + if (!status.ok()) return status; + + NodeDef* real_recv = nullptr; + NodeDef* recv = + AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status); + if (!status.ok()) return status; + + // Fix up the control flow edge. Redirect it to the recv. + // NOTE(yuanbyu): 'real_recv' must be the real recv node. + recv_added = true; + if (control_flow_edge != nullptr) { + AddInput(real_recv, control_flow_edge->src()->name(), + Graph::kControlSlot); + } + + // For same device send/recv, add a control edge from send to recv. + // This prevents the asynchronous recv kernel from being scheduled + // immediately. + if (src_graph == dst_graph) { + AddInput(real_recv, send->name(), Graph::kControlSlot); + } + + if (!edge->IsControlEdge() && + IsRefType(src->output_type(edge->src_output()))) { + // If src is of ref type and the edge is not a control edge, dst has + // read semantics and therefore we must control the recv. + ref_recvs.push_back(real_recv); + } else { + // Memorize the send/recv pair, only if this is not a "ref" edge. + // NOTE(yuanbyu): Collapsing ref edges requires extreme care so + // for now we don't do it. + dup_recv[key] = {recv, real_recv, recv_start_time}; + ref_control_inputs.push_back(recv->name()); + } + + if (edge->IsControlEdge()) { + ++num_control; + AddInput(dst_def, recv->name(), Graph::kControlSlot); + } else { + ++num_data; + AddInput(dst_def, recv->name(), 0); + } + } + + // Add control edges from 'ref_control_inputs' to 'ref_recvs'. + // NOTE(yuanbyu): Adding these control edges should not introduce + // deadlocks. 'dst' has implicit "read" nodes that, when we split + // across devices, are made explicit; Retargettig the dependencies + // to 'dst' to those nodes would not introduce cycles if there isn't + // one before the transformation. 
+ // NOTE(yuanbyu): This may impact performance because it defers the + // execution of recvs until all the other inputs become available. + AddReadControl(ref_recvs, ref_control_inputs); + + // Add back this control edge for control flow if not used. + if (!recv_added && (control_flow_edge != nullptr)) { + AddInput(dst_def, control_flow_edge->src()->name(), Graph::kControlSlot); + } + } + + // Set the start times for recvs at the very end. + if (opts.scheduling_for_recvs) { + for (auto& it : dup_recv) { + AddNodeAttr("_start_time", it.second.start_time, it.second.recv); + if (it.second.real_recv != it.second.recv) { + AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv); + } + } + } + + VLOG(1) << "Added send/recv: controls=" << num_control + << ", data=" << num_data; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h new file mode 100644 index 0000000000..eb88ff71b1 --- /dev/null +++ b/tensorflow/core/graph/graph_partition.h @@ -0,0 +1,77 @@ +#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_ +#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/costmodel.h" + +namespace tensorflow { + +struct PartitionOptions { + // A function that returns a location for the execution of a given + // Node. + typedef std::function NodeToLocFunc; + NodeToLocFunc node_to_loc = nullptr; + + // A function that returns a unique graph node name with the given + // prefix. + typedef std::function NewNameFunc; + NewNameFunc new_name = nullptr; + + // A function that returns the incarnation of a device given the + // device's fullname. If not found, GetIncarnationFunc should return + // kIlledgalIncarnation. + static const uint64 kIllegalIncarnation = 0; + typedef std::function GetIncarnationFunc; + GetIncarnationFunc get_incarnation = nullptr; + + // True if all the control flow "code" has already been added. The + // control flow code needs to be added when we still have the entire + // graph before any partitioning. So this flag should be false for + // the first partitioning but true for all subsequent partitioning. + // + // TODO(yuanbyu): We could also make the addition of the control + // flow code incremental based on 'node_to_loc'. This makes the + // communication a broadcast tree, which could be more efficient when + // the number of participating devices is large. + bool control_flow_added; + + // A function that returns the data type into which the tensor + // should be cast before sent over the wire. + typedef std::function ShouldCastFunc; + ShouldCastFunc should_cast = nullptr; + + // Schedule the execution of the recvs based on their start times + // computed by some scheduling algorithm. The recvs are divided into + // epochs based on their start times. A recv is enabled only when + // execution reaches its epoch - N for some predefined N. + bool scheduling_for_recvs = false; + // The start time for each node in the graph computed by some scheduling + // algorithm. If 'need_to_record_start_times' is true, we record them + // in the graph as a node attribute. + bool need_to_record_start_times = false; + std::vector start_times; +}; + +// Partition "input" graph into a set of graphs, one per location. +// The location for node n is derived by calling opts.node_to_loc(n). 
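+// A minimal calling sketch (names are illustrative; compare the helper in
+// graph_partition_test.cc), assuming a Graph g whose nodes already carry
+// assigned device names:
+//
+//   PartitionOptions opts;
+//   opts.node_to_loc = [](const Node* n) { return n->assigned_device_name(); };
+//   opts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
+//   opts.get_incarnation = [](const string& device) -> uint64 { return 1; };
+//   opts.control_flow_added = false;
+//   std::unordered_map<string, GraphDef> partitions;
+//   Status s = Partition(opts, &g, &partitions);
+//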
+// New nodes added by Partition use "opts.new_name(old_name)" to +// generate node names. +// +// Stores the partitions in *partitions. +Status Partition(const PartitionOptions& opts, Graph* input, + std::unordered_map* partitions); + +// Add control edges to the partitions to control the ordering +// and timing of the recv nodes based on the start times calculated +// using some scheduling algorithm. +Status AddControlEdges(const PartitionOptions& opts, + std::unordered_map* partitions); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_ diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc new file mode 100644 index 0000000000..d912c94025 --- /dev/null +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -0,0 +1,316 @@ +#include "tensorflow/core/graph/graph_partition.h" + +#include + +#include +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/random_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/equal_graph_def.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace { + +const char gpu_device[] = "/job:a/replica:0/task:0/gpu:0"; + +string SplitByDevice(const Node* node) { return node->assigned_device_name(); } + +string DeviceName(const Node* node) { + char first = node->name()[0]; + if (first == 'G') { + return gpu_device; + } else { + const string cpu_prefix = "/job:a/replica:0/task:0/cpu:"; + int index = first - 'A'; + return strings::StrCat(cpu_prefix, index); + } +} + +void Partition(const GraphDef& graph_def, + std::unordered_map* partitions) { + Graph g(OpRegistry::Global()); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g)); + + // Assigns devices to each node. Uses 1st letter of the node name as + // the device index. 
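+  // NOTE: concretely, with DeviceName() above, a node named "A1" is placed on
+  // /job:a/replica:0/task:0/cpu:0, "B2" on .../cpu:1, and "G3" on gpu_device.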
+ for (Node* node : g.nodes()) { + node->set_assigned_device_name(DeviceName(node)); + } + + PartitionOptions popts; + popts.node_to_loc = SplitByDevice; + popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); }; + popts.get_incarnation = [](const string& name) { + return (name[0] - 'A') + 100; + }; + popts.control_flow_added = false; + Status s = Partition(popts, &g, partitions); + CHECK(s.ok()) << s; +} + +void CheckLoopConstruction(const GraphDef& graph_def) { + std::unordered_map partitions; + Partition(graph_def, &partitions); + GraphConstructorOptions opts; + for (const auto& kv : partitions) { + const GraphDef& gdef = kv.second; + bool has_control_enter = false; + bool has_control_merge = false; + bool has_control_switch = false; + bool has_control_next = false; + for (const NodeDef& ndef : gdef.node()) { + // _recvs must have a control input + if (ndef.op() == "_Recv") { + bool has_control = false; + for (const string& input_name : ndef.input()) { + if (StringPiece(input_name).starts_with("^")) { + has_control = true; + break; + } + } + EXPECT_TRUE(has_control); + } + // Must have a control loop + if (StringPiece(ndef.name()).starts_with("_cloop")) { + if (ndef.op() == "Enter") { + has_control_enter = true; + } + if (ndef.op() == "Merge") { + has_control_merge = true; + } + if (ndef.op() == "Switch") { + has_control_switch = true; + } + if (ndef.op() == "NextIteration") { + has_control_next = true; + } + } + } + EXPECT_TRUE(has_control_enter); + EXPECT_TRUE(has_control_merge); + EXPECT_TRUE(has_control_switch); + EXPECT_TRUE(has_control_next); + } +} + +REGISTER_OP("Input").Output("o: float"); +REGISTER_OP("BoolInput").Output("o: bool"); +REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float"); + +Node* Input(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("Input", opts); +} + +Node* BoolInput(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("BoolInput", opts); +} + +Node* Cross(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("Cross", a, b, opts); +} + +class GraphPartitionTest : public ::testing::Test { + protected: + GraphPartitionTest() + : in_(GraphDefBuilder::kFailImmediately), + builder_a_(GraphDefBuilder::kFailImmediately), + builder_b_(GraphDefBuilder::kFailImmediately), + a_opts_(builder_a_.opts().WithDevice("/job:a/replica:0/task:0/cpu:0")), + b_opts_(builder_b_.opts().WithDevice("/job:a/replica:0/task:0/cpu:1")) { + RequireDefaultOps(); + } + + const GraphDef& ToGraphDef() { + in_.ToGraphDef(&in_graph_def_); + return in_graph_def_; + } + + void ExpectMatchA() { + GraphDef graph_def; + builder_a_.ToGraphDef(&graph_def); + string a = "/job:a/replica:0/task:0/cpu:0"; + TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]); + } + + void ExpectMatchB() { + GraphDef graph_def; + builder_b_.ToGraphDef(&graph_def); + string b = "/job:a/replica:0/task:0/cpu:1"; + TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]); + } + + GraphDefBuilder in_; + GraphDef in_graph_def_; + GraphDefBuilder builder_a_; + GraphDefBuilder builder_b_; + GraphDefBuilder::Options a_opts_; + GraphDefBuilder::Options b_opts_; + std::unordered_map partitions_; +}; + +TEST_F(GraphPartitionTest, SingleDevice) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Cross(a1, a1, in_.opts().WithName("A2")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(1, partitions_.size()); + + a1 = Input(a_opts_.WithName("A1")); + Cross(a1, a1, 
a_opts_.WithName("A2")); + ExpectMatchA(); +} + +TEST_F(GraphPartitionTest, CrossDeviceData) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0")); + ExpectMatchA(); + + b1 = Input(b_opts_.WithName("B1")); + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1")); + Cross(recv, b1, b_opts_.WithName("B2")); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceControl) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst(a_opts_.WithName("A1/_0").WithControlInput(a1)); + _Send(c, "edge_3_A1", a, 82, b, a_opts_.WithName("A1/_1")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_3_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id = Identity(recv, b_opts_.WithName("A1/_3")); + b1 = Input(b_opts_.WithName("B1")); + Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id)); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + Cross(a1, a1, in_.opts().WithName("B3")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1")); + b1 = Input(b_opts_.WithName("B1")); + Cross(recv, b1, b_opts_.WithName("B2")); + Cross(recv, recv, b_opts_.WithName("B3")); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1)); + Input(in_.opts().WithName("B3").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst(a_opts_.WithName("A1/_0").WithControlInput(a1)); + _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id = Identity(recv, b_opts_.WithName("A1/_3")); + b1 = Input(b_opts_.WithName("B1")); + Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id)); + Input(b_opts_.WithName("B3").WithControlInput(id)); + ExpectMatchB(); +} + 
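+// NOTE: a reading aid for the expected graphs above. The "82" passed to
+// _Send/_Recv is the send_device_incarnation computed by the get_incarnation
+// lambda in Partition(), applied to the device fullname:
+//
+//   ("/job:a/replica:0/task:0/cpu:0"[0] - 'A') + 100 == ('/' - 'A') + 100 == 82
+//
+// The "edge_<k>_A1" tensor names appear to encode the graph edge id plus the
+// source node name, and the "A1/_<k>" node names come from opts.new_name(),
+// i.e. Graph::NewName() applied to the source node's name.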
+TEST_F(GraphPartitionTest, CrossDevice_DataControl) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + Input(in_.opts().WithName("B3").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst(a_opts_.WithName("A1/_0").WithControlInput(a1)); + // NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could + // use A1/_0 -> A1/_4 as the control as a minor optimization. + _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1")); + _Send(a1, "edge_2_A1", a, 82, b, a_opts_.WithName("A1/_4")); + ExpectMatchA(); + + Node* recv1 = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id1 = Identity(recv1, b_opts_.WithName("A1/_3")); + Node* recv2 = + _Recv(DT_FLOAT, "edge_2_A1", a, 82, b, b_opts_.WithName("A1/_5")); + b1 = Input(b_opts_.WithName("B1")); + Cross(recv2, b1, b_opts_.WithName("B2")); + Input(b_opts_.WithName("B3").WithControlInput(id1)); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceLoop) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = BoolInput(in_.opts().WithName("A1")); + Node* a2 = Enter(a1, "foo", in_.opts().WithName("A2")); + Node* a3 = Merge({a2, {"A5", 0, DT_BOOL}}, in_.opts().WithName("A3")); + LoopCond(a3, in_.opts().WithName("A4")); + Node* b1 = Identity(a3, in_.opts().WithName("B1")); + NextIteration(b1, in_.opts().WithName("A5")); + + CheckLoopConstruction(ToGraphDef()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc new file mode 100644 index 0000000000..f7a8ffde89 --- /dev/null +++ b/tensorflow/core/graph/graph_test.cc @@ -0,0 +1,252 @@ +#include "tensorflow/core/graph/graph.h" + +#include +#include +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +class GraphTest : public ::testing::Test { + protected: + GraphTest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); } + ~GraphTest() override {} + + static void VerifyNodes(Node* node, std::vector expected_in, + std::vector expected_out) { + std::vector in; + for (const Edge* e : node->in_edges()) { + in.push_back(e->src()); + } + EXPECT_EQ(Stringify(expected_in), Stringify(in)); + + std::vector out; + for (const Edge* e : node->out_edges()) { + out.push_back(e->dst()); + } + EXPECT_EQ(Stringify(expected_out), Stringify(out)); + } + + Node* AddNodeWithName(const string& name) { + Node* node; + TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node)); + return node; + } + + Graph graph_; + + private: + // Convert a list of nodes to a sorted list of strings so failure messages + // are readable. 
+ static std::vector Stringify(const std::vector& nodes) { + std::vector result; + for (Node* n : nodes) { + result.push_back(n->DebugString()); + } + std::sort(result.begin(), result.end()); + return result; + } +}; + +TEST_F(GraphTest, Constructor) { + Node* source = graph_.source_node(); + EXPECT_NE(source, nullptr); + Node* sink = graph_.sink_node(); + EXPECT_NE(sink, nullptr); + VerifyNodes(source, {}, {sink}); + VerifyNodes(sink, {source}, {}); + EXPECT_EQ(2, graph_.num_node_ids()); +} + +TEST_F(GraphTest, RemoveThenAdd) { + AddNodeWithName("A"); + Node* b = AddNodeWithName("B"); + const int b_id = b->id(); + AddNodeWithName("C"); + EXPECT_EQ(5, graph_.num_node_ids()); + graph_.RemoveNode(b); + EXPECT_EQ(5, graph_.num_node_ids()); + Node* d = AddNodeWithName("D"); + EXPECT_NE(b_id, d->id()); // Ids should not be reused. + EXPECT_EQ(6, graph_.num_node_ids()); +} + +TEST_F(GraphTest, InNodesAndOutNodes) { + Node* a = AddNodeWithName("A"); + Node* b = AddNodeWithName("B"); + Node* c = AddNodeWithName("C"); + graph_.RemoveNode(b); + Node* d = AddNodeWithName("D"); + + const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a); + graph_.AddControlEdge(a, graph_.sink_node()); + graph_.AddEdge(a, 0, c, 0); + graph_.AddControlEdge(c, graph_.sink_node()); + + EXPECT_EQ("A", a->name()); + VerifyNodes(a, {graph_.source_node()}, {c, graph_.sink_node()}); + + EXPECT_EQ("C", c->name()); + VerifyNodes(c, {a}, {graph_.sink_node()}); + + EXPECT_EQ("D", d->name()); + VerifyNodes(d, {}, {}); + + VerifyNodes(graph_.source_node(), {}, {a, graph_.sink_node()}); + VerifyNodes(graph_.sink_node(), {a, c, graph_.source_node()}, {}); + + graph_.RemoveEdge(source_to_a); + VerifyNodes(a, {}, {c, graph_.sink_node()}); + VerifyNodes(graph_.source_node(), {}, {graph_.sink_node()}); // no more a + + graph_.RemoveNode(c); + VerifyNodes(a, {}, {graph_.sink_node()}); // no more c + VerifyNodes(graph_.sink_node(), {a, graph_.source_node()}, {}); // no more c + EXPECT_EQ(6, graph_.num_node_ids()); + EXPECT_EQ(5, graph_.num_edge_ids()); +} + +TEST_F(GraphTest, NodeIteration) { + // Set up the graph with some holes due to removals. + Node* a = AddNodeWithName("A"); + Node* b = AddNodeWithName("B"); + Node* c = AddNodeWithName("C"); + graph_.RemoveNode(b); + Node* d = AddNodeWithName("D"); + const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a); + graph_.AddControlEdge(a, graph_.sink_node()); + graph_.AddEdge(a, 0, c, 0); + graph_.AddControlEdge(c, graph_.sink_node()); + graph_.RemoveEdge(source_to_a); + graph_.RemoveNode(c); + + // expected = set of all node DebugStrings we expect in the graph + std::set expected; + expected.insert(graph_.source_node()->DebugString()); + expected.insert(a->DebugString()); + expected.insert(d->DebugString()); + expected.insert(graph_.sink_node()->DebugString()); + + // Verify that iterating through ids gets the same set of nodes. + std::set actual; + for (int id = 0; id < graph_.num_node_ids(); ++id) { + Node* node = graph_.FindNodeId(id); + if (node != nullptr) { + actual.insert(node->DebugString()); + } + } + EXPECT_EQ(expected, actual); + + // Verify that range-based for loop gets the same set of nodes. + actual.clear(); + for (Node* node : graph_.nodes()) { + actual.insert(node->DebugString()); + } + EXPECT_EQ(expected, actual); +} + +static void CheckType(Node* node, bool b) { + EXPECT_TRUE(b) << node->DebugString(); + // Make sure none of the other IsFoo() methods return true. 
+ int count = 0; + if (node->IsSource()) count++; + if (node->IsSink()) count++; + if (node->IsOp()) count++; + EXPECT_EQ(1, count) << node->DebugString(); +} + +TEST_F(GraphTest, Type) { + Node* op = AddNodeWithName("A"); + CheckType(graph_.source_node(), graph_.source_node()->IsSource()); + CheckType(graph_.sink_node(), graph_.sink_node()->IsSink()); + CheckType(op, op->IsOp()); +} + +// Convert edge iteration results into a sorted string. +static string EdgeIter(const Graph& g) { + std::vector > edges; + for (const Edge* e : g.edges()) { + edges.push_back(std::make_pair(e->src()->id(), e->dst()->id())); + } + std::sort(edges.begin(), edges.end()); + string result; + for (auto& p : edges) { + strings::StrAppend(&result, p.first, "->", p.second, ";"); + } + return result; +} + +TEST_F(GraphTest, EdgeIteration) { + EXPECT_EQ("0->1;", EdgeIter(graph_)); + + Node* a = AddNodeWithName("A"); + Node* b = AddNodeWithName("B"); + EXPECT_EQ("0->1;", EdgeIter(graph_)); // Since a,b are currently disconnected + + graph_.AddEdge(a, 0, b, 0); + EXPECT_EQ("0->1;2->3;", EdgeIter(graph_)); + + graph_.AddControlEdge(graph_.source_node(), a); + graph_.AddControlEdge(b, graph_.sink_node()); + EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_)); + + graph_.AddEdge(a, 1, a, 0); + EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_)); +} + +TEST_F(GraphTest, NewName) { + string a1 = graph_.NewName("A"); + string a2 = graph_.NewName("A"); + string b1 = graph_.NewName("B"); + EXPECT_NE(a1, a2); + EXPECT_NE(a1, b1); + EXPECT_NE(a2, b1); + EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1; +} + +REGISTER_OP("Input").Output("o: float"); +REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("o: float"); + +static void BM_InEdgeIteration(int iters, int num_nodes) { + testing::StopTiming(); + string s; + for (int in = 0; in < 10; in++) { + s += strings::Printf("node { name: 'in%04d' op: 'Input' }", in); + } + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int op = 0; op < num_nodes; op++) { + s += strings::Printf( + "node { name: 'op%04d' op: 'In2Out1' input: ['in%04d', 'in%04d' ] }", + op, rnd.Uniform(10), rnd.Uniform(10)); + } + + Graph graph(OpRegistry::Global()); + GraphDef graph_def; + CHECK(protobuf::TextFormat::ParseFromString(s, &graph_def)); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); + + int64 sum = 0; + testing::StartTiming(); + for (int i = 0; i < iters; i += graph.num_node_ids()) { + for (const Node* node : graph.nodes()) { + for (auto e : node->in_edges()) { + sum += e->id(); + } + } + } + VLOG(1) << sum; +} +BENCHMARK(BM_InEdgeIteration)->Range(10, 100000); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc new file mode 100644 index 0000000000..8c34323dbe --- /dev/null +++ b/tensorflow/core/graph/node_builder.cc @@ -0,0 +1,115 @@ +#include "tensorflow/core/graph/node_builder.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +NodeBuilder::NodeBuilder(const string& name, const string& op_name, + const OpRegistryInterface* op_registry) + : def_builder_(name, op_name, op_registry) {} + +NodeBuilder::NodeBuilder(const string& name, const OpDef* op_def) + : def_builder_(name, op_def) {} + +NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) { + inputs_.emplace_back(src_node, src_index); + DataType dt; + if 
(GetOutputType(src_node, src_index, &dt)) { + def_builder_.Input(src_node->name(), src_index, dt); + } + return *this; +} + +NodeBuilder& NodeBuilder::Input(NodeOut src) { + if (src.error) { + AddIndexError(src.node, src.index); + } else { + inputs_.emplace_back(src.node, src.index); + def_builder_.Input(src.name, src.index, src.dt); + } + return *this; +} + +NodeBuilder& NodeBuilder::Input(gtl::ArraySlice src_list) { + std::vector srcs; + srcs.reserve(src_list.size()); + for (const auto& node_out : src_list) { + if (node_out.error) { + AddIndexError(node_out.node, node_out.index); + } else { + srcs.emplace_back(node_out.name, node_out.index, node_out.dt); + inputs_.emplace_back(node_out.node, node_out.index); + } + } + def_builder_.Input(srcs); + return *this; +} + +NodeBuilder& NodeBuilder::ControlInput(Node* src_node) { + control_inputs_.emplace_back(src_node); + def_builder_.ControlInput(src_node->name()); + return *this; +} + +NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice src_nodes) { + control_inputs_.insert(control_inputs_.end(), src_nodes.begin(), + src_nodes.end()); + for (Node* src_node : src_nodes) { + def_builder_.ControlInput(src_node->name()); + } + return *this; +} + +NodeBuilder& NodeBuilder::Device(const string& device_spec) { + def_builder_.Device(device_spec); + return *this; +} + +Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { + // In case of error, set *created_node to nullptr. + if (created_node != nullptr) *created_node = nullptr; + if (!errors_.empty()) { + return errors::InvalidArgument(str_util::Join(errors_, "\n")); + } + + NodeDef node_def; + TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def)); + TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); + Status status; + Node* node = graph->AddNode(node_def, &status); + if (!status.ok()) return status; + + for (size_t i = 0; i < inputs_.size(); ++i) { + if (inputs_[i].node != nullptr) { // Skip back edges. + graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i); + } + } + for (Node* control_input : control_inputs_) { + graph->AddControlEdge(control_input, node); + } + if (created_node != nullptr) *created_node = node; + return Status::OK(); +} + +void NodeBuilder::AddIndexError(Node* node, int i) { + if (node == nullptr) { + errors_.emplace_back( + strings::StrCat("Attempt to add nullptr Node to node with type", + def_builder_.op_def().name())); + } else { + errors_.emplace_back( + strings::StrCat("Attempt to add output ", i, " of ", node->name(), + " not in range [0, ", node->num_outputs(), + ") to node with type ", def_builder_.op_def().name())); + } +} + +bool NodeBuilder::GetOutputType(Node* node, int i, DataType* dt) { + bool error; + *dt = SafeGetOutput(node, i, &error); + if (error) AddIndexError(node, i); + return !error; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h new file mode 100644 index 0000000000..dd34b97f23 --- /dev/null +++ b/tensorflow/core/graph/node_builder.h @@ -0,0 +1,146 @@ +#ifndef TENSORFLOW_GRAPH_NODE_BUILDER_H_ +#define TENSORFLOW_GRAPH_NODE_BUILDER_H_ + +#include +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// This is a helper for creating a Node and adding it to a Graph. 
+// Internally, it uses a NodeDefBuilder to automatically set attrs +// that can be inferred from the inputs, and use default values +// (where they exist) for unspecified attrs. Example usage: +// +// Node* node; +// Status status = NodeBuilder(node_name, op_name) +// .Input(...) +// .Attr(...) +// .Finalize(&graph, &node); +// if (!status.ok()) return status; +// // Use node here. +class NodeBuilder { + public: + // For specifying the output of a Node to provide to one of the Input() + // functions below. It supports both regular inputs (where you are + // connecting to an existing Node*), and inputs from outside the graph + // (or haven't been added to the graph yet, like back edges, where + // you don't have a Node*). Both types can be mixed, e.g. in an + // ArraySlice. + struct NodeOut { + // For referencing an existing Node. + NodeOut(Node* n, int i = 0) // NOLINT(runtime/explicit) + : node(n), + error(false), + name(node != nullptr ? node->name() : (error = true, "")), + index(i), + dt(SafeGetOutput(node, i, &error)) {} + + // For referencing Nodes not in the graph being built. It is + // useful when preparing a graph for ExtendSession or creating a + // back edge to a node that hasn't been added to the graph yet, + // but will be. + NodeOut(const string& name, int i, DataType t) + : node(nullptr), error(false), name(name), index(i), dt(t) {} + + // Default constructor for std::vector. + NodeOut() {} + + Node* node = nullptr; + // error is set to true if: + // * the NodeOut was default constructed and never overwritten, + // * a nullptr Node* was passed to the NodeOut constructor, or + // * an out-of-range index was passed to the NodeOut constructor. + bool error = true; + string name; + int index = 0; + DataType dt = DT_FLOAT; + }; + + // Specify the name and the Op (either via an OpDef or the name of + // the Op plus a registry) for the Node. Other fields are + // specified by calling the methods below. + // REQUIRES: The OpDef must satisfy ValidateOpDef(). + NodeBuilder(const string& name, const string& op_name, + const OpRegistryInterface* op_registry = OpRegistry::Global()); + NodeBuilder(const string& name, const OpDef* op_def); + + // You must call one Input() function per input_arg in the Op, + // *and in the same order as the input_args appear in the OpDef.* + + // For inputs that take a single tensor. + NodeBuilder& Input(Node* src_node, int src_index = 0); + NodeBuilder& Input(NodeOut src); + + // For inputs that take a list of tensors. + NodeBuilder& Input(gtl::ArraySlice src_list); + + // Require that this node run after src_node(s). + NodeBuilder& ControlInput(Node* src_node); + NodeBuilder& ControlInputs(gtl::ArraySlice src_nodes); + + // Sets the "requested device spec" in the NodeDef (not the + // "assigned device" in the Node). + NodeBuilder& Device(const string& device_spec); + + // Set the value of an attr. attr_name must match the name of one of + // attrs defined by the Op, and value must have the corresponding type + // (see SetAttrValue() in ../framework/attr_value_util.h for legal + // types for value). Note that attrs will be set automatically if + // they can be determined by the inputs. + template + NodeBuilder& Attr(const string& attr_name, T&& value); + template + NodeBuilder& Attr(const string& attr_name, std::initializer_list value); + + // Validates the described node and adds it to *graph, adding edges + // for all (non-back) inputs. If created_node is not nullptr, + // *created_node will be set to the new node (or nullptr on error). 
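+  // NOTE: a hedged sketch of the back-edge form of NodeOut described above
+  // (names are illustrative). An input that does not exist yet can be named
+  // up front and wired in after it is created:
+  //
+  //   Node* merge;
+  //   TF_RETURN_IF_ERROR(NodeBuilder("merge", "Merge")
+  //                          .Input({NodeOut(enter), NodeOut("next", 0, DT_FLOAT)})
+  //                          .Finalize(graph, &merge));
+  //   // ... later, once the "next" node exists:
+  //   graph->AddEdge(next_node, 0, merge, 1);
+  //
+  // Finalize() skips inputs whose Node* is nullptr (back edges), so the second
+  // edge must be added explicitly as above.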
+ Status Finalize(Graph* graph, Node** created_node) const; + + private: + static DataType SafeGetOutput(Node* node, int i, bool* error) { + if (node != nullptr && i >= 0 && i < node->num_outputs()) { + *error = false; + return node->output_type(i); + } else { + *error = true; + return DT_FLOAT; + } + } + + // If SafeGetOutput indicates a range error, add it to errors_. + void AddIndexError(Node* node, int i); + + // Set *dt and returns true if i is in range. Combines + // SafeGetOutput() and AddIndexError(). + bool GetOutputType(Node* node, int i, DataType* dt); + + NodeDefBuilder def_builder_; + std::vector inputs_; + std::vector control_inputs_; + std::vector errors_; +}; + +// IMPLEMENTATION ------------------------------------------------------------- + +template +inline NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) { + def_builder_.Attr(attr_name, std::forward(value)); + return *this; +} + +template +NodeBuilder& NodeBuilder::Attr(const string& attr_name, + std::initializer_list value) { + def_builder_.Attr(attr_name, value); + return *this; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_NODE_BUILDER_H_ diff --git a/tensorflow/core/graph/node_builder_test.cc b/tensorflow/core/graph/node_builder_test.cc new file mode 100644 index 0000000000..9f667d00e4 --- /dev/null +++ b/tensorflow/core/graph/node_builder_test.cc @@ -0,0 +1,59 @@ +#include "tensorflow/core/graph/node_builder.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include + +namespace tensorflow { +namespace { + +REGISTER_OP("Source").Output("o: out_types").Attr("out_types: list(type)"); +REGISTER_OP("Sink").Input("i: T").Attr("T: type"); + +TEST(NodeBuilderTest, Simple) { + RequireDefaultOps(); + Graph graph(OpRegistry::Global()); + Node* source_node; + EXPECT_OK(NodeBuilder("source_op", "Source") + .Attr("out_types", {DT_INT32, DT_STRING}) + .Finalize(&graph, &source_node)); + ASSERT_TRUE(source_node != nullptr); + + // Try connecting to each of source_node's outputs. + EXPECT_OK(NodeBuilder("sink1", "Sink") + .Input(source_node) + .Finalize(&graph, nullptr)); + EXPECT_OK(NodeBuilder("sink2", "Sink") + .Input(source_node, 1) + .Finalize(&graph, nullptr)); + + // Generate an error if the index is out of range. + EXPECT_FALSE(NodeBuilder("sink3", "Sink") + .Input(source_node, 2) + .Finalize(&graph, nullptr) + .ok()); + EXPECT_FALSE(NodeBuilder("sink4", "Sink") + .Input(source_node, -1) + .Finalize(&graph, nullptr) + .ok()); + EXPECT_FALSE(NodeBuilder("sink5", "Sink") + .Input({source_node, -1}) + .Finalize(&graph, nullptr) + .ok()); + + // Generate an error if the node is nullptr. This can happen when using + // GraphDefBuilder if there was an error creating the input node. + EXPECT_FALSE(NodeBuilder("sink6", "Sink") + .Input(nullptr) + .Finalize(&graph, nullptr) + .ok()); + EXPECT_FALSE(NodeBuilder("sink7", "Sink") + .Input(NodeBuilder::NodeOut(nullptr, 0)) + .Finalize(&graph, nullptr) + .ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc new file mode 100644 index 0000000000..2fa6f075c0 --- /dev/null +++ b/tensorflow/core/graph/optimizer_cse.cc @@ -0,0 +1,220 @@ +// This module implements a common subexpression elimination pass. We +// process the nodes in the graph in reverse postorder +// (i.e. inputs before their downstream dependencies). 
The rough algorithm is +// as follows: +// +// std::unordered_map available +// for each node n in forward topological order: +// h = NodeHash(n) +// if available[h] exists and Equivalent(available(h), h) +// redirect downstream uses of outputs of n to available[h] +// remove n from graph +// else +// if available[h] does not exist +// available[h] = n +// +// This is similar to the global value number algorithm describe in this +// paper: +// +// "Global code motion/global value numbering", Cliff Click, PLDI '95 +// Proceedings of the ACM SIGPLAN 1995 conference on Programming +// language design and implementation, Pages 246-257 +// http://dl.acm.org/citation.cfm?id=207154 + +#include "tensorflow/core/graph/optimizer_cse.h" + +#include + +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +class OptimizerCSE { + public: + explicit OptimizerCSE(Graph* g) : g_(g) {} + + void Optimize(std::function consider_fn); + + private: + struct Scratch; + + static size_t NodeHash(const Node* n); + static bool Equivalent(const Node* a, const Node* b, Scratch* s); + static bool EqualAttrs(const Node* a, const Node* b, Scratch* s); + + Graph* g_; +}; + +static void FillInputs(const Node* n, + gtl::InlinedVector* control_edges, + gtl::InlinedVector, 4>* in) { + DCHECK_EQ(in->size(), n->num_inputs()); + control_edges->clear(); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + control_edges->push_back(e->src()); + } else { + (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); + } + } + std::sort(control_edges->begin(), control_edges->end()); + if (n->op_def().is_commutative()) { + // For commutative inputs, we sort the input by the input Node* + // to get a canonical ordering (so that add(a,b) and add(b, a) will + // hash to the same value if is_commutative is true for 'add'). + std::sort(in->begin(), in->end()); + } +} + +static size_t kIllegalNodeHash = 0; + +size_t OptimizerCSE::NodeHash(const Node* n) { + const DataTypeVector& out = n->output_types(); + string str_to_hash = strings::StrCat(n->type_string(), out.size()); + for (DataType dt : out) { + strings::StrAppend(&str_to_hash, dt); + } + + const int N_in = n->num_inputs(); + strings::StrAppend(&str_to_hash, N_in); + gtl::InlinedVector control_edges; + gtl::InlinedVector, 4> in(N_in); + FillInputs(n, &control_edges, &in); + for (const auto& edge : in) { + strings::StrAppend(&str_to_hash, edge.first->id(), edge.second); + } + + size_t h = Hash64(str_to_hash); + +#if !defined(__ANDROID__) && !defined(ANDROID) + // Hash the attrs. For example, this makes sure different constants + // end up in different hash buckets. + string tmp; + for (const auto& attr : n->def().attr()) { + tmp = attr.first; + attr.second.AppendToString(&tmp); + // Add hashes of attrs, so the order of attrs doesn't matter. 
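+    // NOTE: summing the per-attr hashes makes the result independent of the
+    // attr map's iteration order, e.g. {a: 3, t: DT_INT32} and
+    // {t: DT_INT32, a: 3} hash identically (see Simple_SameOps_SameAttrs2 in
+    // optimizer_cse_test.cc), while a change to any attr value still changes
+    // the sum with high probability.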
+ h += Hash32(tmp.data(), tmp.size(), 0x87341245); + } +#endif + + if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1; + return h; +} + +struct OptimizerCSE::Scratch { + // For EqualAttrs(): + string a; + string b; +}; + +bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) { + if (a->def().attr_size() != b->def().attr_size()) return false; + + for (const auto& attr : b->def().attr()) { + auto iter = a->def().attr().find(attr.first); + if (iter == a->def().attr().end()) return false; + // Note: it should be safe to compare proto serializations of the attr + // values since at most one field should be set in each (indeed, it + // should be the same field). + iter->second.SerializeToString(&scratch->a); + attr.second.SerializeToString(&scratch->b); + if (scratch->a != scratch->b) return false; + } + return true; +} + +static bool HasRefInput(const Node* n) { + for (auto dt : n->input_types()) { + if (IsRefType(dt)) return true; + } + return false; +} + +bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) { + // Different op names are different + if (a->type_string() != b->type_string()) return false; + + // Never consider stateful nodes (such as non-const inputs) equivalent. + if (a->op_def().is_stateful()) return false; + + // For now, we consider any node that takes a ref input to not be + // equivalent to any other node. + if (HasRefInput(a) || HasRefInput(b)) return false; + + // Compare attrs. Note that equal attrs implies equal input and + // output types. + if (!EqualAttrs(a, b, scratch)) return false; + + // Compare input sources + if (a->num_inputs() != b->num_inputs()) return false; + const int N_in = a->num_inputs(); + gtl::InlinedVector a_control_edges; + gtl::InlinedVector b_control_edges; + gtl::InlinedVector, 4> a_in(N_in); + gtl::InlinedVector, 4> b_in(N_in); + FillInputs(a, &a_control_edges, &a_in); + FillInputs(b, &b_control_edges, &b_in); + if (a_in != b_in) return false; + if (a_control_edges != b_control_edges) return false; + + return true; +} + +void OptimizerCSE::Optimize(std::function consider_fn) { + // This very simple implementation works if the whole graph is one + // giant basic block (because we just traverse nodes in a + // topological order). We'll need to do something more + // sophisticated when we have control flow/loops/etc. + + // TODO(jeff): We need to handle Update nodes specially, but dealing + // with more general control flow will also solve this issue, and for + // now, our updates are almost always the most downstream nodes in + // the graph. + std::vector order; + GetReversePostOrder(*g_, &order); + + // Our value is just a single Node*, meaning we keep just a single + // candidate for a given node hash value. This may cause us to + // (rarely) lose some optimization opportunities if there are + // hash collisions, but it allows us to avoid having the value + // be a set (or equivalent). + std::unordered_map available; + + // Scratch space for Equivalent calls. Allocated here and passed in to + // Equivalent to avoid allocation inside the loop below. 
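+  // NOTE: a sketch of the loop below, assuming two equivalent nodes
+  // C = Mul(A, B) and D = Mul(A, B): whichever is visited first becomes the
+  // candidate stored under their shared hash; when the other is visited,
+  // Equivalent() confirms the match, its outgoing edges are re-pointed at the
+  // candidate, and the duplicate node is removed from the graph.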
+ Scratch scratch; + for (Node* n : order) { + if (!n->IsOp()) continue; + + // See if we should consider this node at all + if (consider_fn != nullptr && !consider_fn(n)) continue; + + size_t h = NodeHash(n); + Node** candidate = &available[h]; + if (*candidate == nullptr) { + // No existing match: insert "n" into the hash table under "h" + *candidate = n; + } else if (Equivalent(*candidate, n, &scratch)) { + VLOG(1) << "CSE: equivalent: " << (*candidate)->name() << " and " + << n->name(); + // *candidate and n are equivalent. Therefore, we can replace + // n with *candidate by fixing up outgoing edges from "n" to instead + // come from "*candidate", and then delete n from the graph + for (const Edge* e : n->out_edges()) { + g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input()); + } + g_->RemoveNode(n); + } + } +} + +void OptimizeCSE(Graph* g, std::function consider_fn) { + OptimizerCSE opt(g); + opt.Optimize(consider_fn); +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h new file mode 100644 index 0000000000..430c97a449 --- /dev/null +++ b/tensorflow/core/graph/optimizer_cse.h @@ -0,0 +1,19 @@ +// An optimization pass that performs common subexpression elimination. + +#ifndef TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_ +#define TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_ + +#include +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Perform common-subexpression elimination on the graph "*g". If +// "consider_fn" is not nullptr, then only nodes for which +// consider_fn(node) returns true will be considered for combining +// during the common subexpression elimination. +extern void OptimizeCSE(Graph* g, std::function consider_fn); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_ diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc new file mode 100644 index 0000000000..ebbb948fdc --- /dev/null +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -0,0 +1,365 @@ +#include "tensorflow/core/graph/optimizer_cse.h" + +#include +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace { + +static void InitGraph(const string& s, Graph* graph) { + GraphDef graph_def; + + auto parser = protobuf::TextFormat::Parser(); + // parser.AllowRelaxedWhitespace(true); + CHECK(parser.MergeFromString(s, &graph_def)) << s; + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); +} + +class OptimizerCSETest : public ::testing::Test { + public: + OptimizerCSETest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); } + + void InitGraph(const string& s) { + ::tensorflow::InitGraph(s, &graph_); + original_ = CanonicalGraphString(&graph_); + } + + static bool IncludeNode(const Node* n) { return n->IsOp(); } + + static string EdgeId(const Node* n, int index) { + if (index == 0) { + return n->name(); + } else if (index == Graph::kControlSlot) { + return 
strings::StrCat(n->name(), ":control"); + } else { + return strings::StrCat(n->name(), ":", index); + } + } + + string CanonicalGraphString(Graph* g) { + std::vector nodes; + std::vector edges; + for (const Node* n : g->nodes()) { + if (IncludeNode(n)) { + nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")")); + } + } + for (const Edge* e : g->edges()) { + if (IncludeNode(e->src()) && IncludeNode(e->dst())) { + edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->", + EdgeId(e->dst(), e->dst_input()))); + } + } + // Canonicalize + std::sort(nodes.begin(), nodes.end()); + std::sort(edges.begin(), edges.end()); + return strings::StrCat(str_util::Join(nodes, ";"), "|", + str_util::Join(edges, ";")); + } + + string DoCSE(std::function consider_fn = nullptr) { + string before = CanonicalGraphString(&graph_); + LOG(ERROR) << "Before rewrites: " << before; + + OptimizeCSE(&graph_, consider_fn); + + string result = CanonicalGraphString(&graph_); + LOG(ERROR) << "After rewrites: " << result; + return result; + } + + const string& OriginalGraph() const { return original_; } + + Graph graph_; + string original_; +}; + +REGISTER_OP("Input").Output("o: float").SetIsStateful(); + +// Note that the "rules" in these tests are not meant to be logically correct +TEST_F(OptimizerCSETest, Simple) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul)|" + "A->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);E(Mul)|" + "A->E;B->E:1"); +} + +TEST_F(OptimizerCSETest, Simple_WithFixups) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D'] }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul);E(Mul)|" + "A->D;B->D:1;D->E;D->E:1"); +} + +TEST_F(OptimizerCSETest, Simple_Commutative) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'A'] }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul)|" + "A->D:1;B->D"); +} + +static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; } + +// Like Simple_Commutative, +TEST_F(OptimizerCSETest, Simple_Filtered) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'A'] 
}"); + EXPECT_EQ(DoCSE(IsNotMultiply), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, Simple_NotCommutative) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'A'] }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, NotEquivalent_Ops) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) { + // Should still do CSE for ops with attrs if they match. + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] attr { key: 'shape'" + " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] attr { key: 'shape'" + " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul)|" + "A->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) { + // Should still do CSE for ops with attrs if they match, even if they + // are not in the same order. + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 'a' value { i: 3 } }" + " attr { key: 't' value { type: DT_INT32 } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 't' value { type: DT_INT32 } }" + " attr { key: 'a' value { i: 3 } } }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul)|" + "A->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, SameConstants) { + // Should still do CSE for ops with constants if the values are identical + InitGraph( + "node { name: 'A' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value {" + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'B' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value {" + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), + "B(Const);D(Mul)|" + "B->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, DifferentConstants) { + // Should still do CSE for ops with extensions if the extensions are identical + InitGraph( + "node { name: 'A' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value {" + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'B' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value {" + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 100000 } } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), + 
"A(Const);B(Const);D(Mul)|" + "A->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, SameOps_DifferentAttrs1) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 'a' value { i: 3 } }" + " attr { key: 't' value { type: DT_INT32 } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 't' value { type: DT_INT32 } }" + " attr { key: 'a' value { i: 4 } } }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, SameOps_DifferentAttrs2) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 'a' value { i: 3 } }" + " attr { key: 't' value { type: DT_FLOAT } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 't' value { type: DT_INT32 } }" + " attr { key: 'a' value { i: 3 } } }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, NotEquivalent_Inputs) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'C'] }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, Constant_Dedup) { + Tensor a(DT_FLOAT, TensorShape({1})); + a.flat()(0) = 1.0; + Tensor b(DT_DOUBLE, TensorShape({1})); // Different type + b.flat()(0) = 1.0; + Tensor c(DT_FLOAT, TensorShape({1, 1})); // Different shape + c.flat()(0) = 1.0; + Tensor d(DT_FLOAT, TensorShape({1})); // Different value + d.flat()(0) = 2.0; + + // A graph contains a bunch of constants. + Graph g(OpRegistry::Global()); + for (auto val : {a, b, c, d, d, c, b, a}) { + test::graph::Constant(&g, val); // Node name is n/_0, n/_1, ... + } + GraphDef gdef; + test::graph::ToGraphDef(&g, &gdef); + InitGraph(gdef.DebugString()); + + EXPECT_EQ(OriginalGraph(), + "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);" + "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|"); + // In theory, there are 2^4 possible correct output of CSE. In this + // test, it happens happens to eliminate the first 4 nodes. + EXPECT_EQ(DoCSE(), "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|"); +} + +static void BM_CSE(int iters, int op_nodes) { + testing::StopTiming(); + string s; + for (int in = 0; in < 10; in++) { + s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in); + } + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int op = 0; op < op_nodes; op++) { + s += strings::Printf( + "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { " + "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }", + op, rnd.Uniform(10), rnd.Uniform(10)); + } + + bool first = true; + while (iters > 0) { + Graph* graph = new Graph(OpRegistry::Global()); + InitGraph(s, graph); + int N = graph->num_node_ids(); + if (first) { + testing::SetLabel(strings::StrCat("Per graph node. 
Nodes: ", N)); + first = false; + } + { + testing::StartTiming(); + OptimizeCSE(graph, nullptr); + testing::StopTiming(); + } + iters -= N; // Our benchmark units are individual graph nodes, + // not whole graphs + delete graph; + } +} +BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc new file mode 100644 index 0000000000..7910511dfb --- /dev/null +++ b/tensorflow/core/graph/subgraph.cc @@ -0,0 +1,258 @@ +#include "tensorflow/core/graph/subgraph.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// ---------------------------------------------------------------------------- +// Subgraph construction-related routines +// ---------------------------------------------------------------------------- +// TODO(vrv): Profile the unordered_set and unordered_map use in this file to +// see if we should use an alternative implementation. + +namespace { + +typedef std::unordered_map NameIndex; + +// Rewrite graph by replacing the output tensors specified in +// "fed_outputs" with special feed nodes for each specified output +// tensor, and removing any nodes that are now disconnected from the +// part of the graph that reaches the sink node. The set of special +// feed nodes added to the graph are returned in "*feed_nodes". +// +// Return true on success. On error, return false and sets *error to +// an appropriate error message (and *g is left in an indeterminate +// state). +static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, + const gtl::ArraySlice& fed_outputs, + NameIndex* name_index) { + for (const string& t : fed_outputs) { + TensorId id(ParseTensorName(t)); + + auto iter = name_index->find(id.first); + if (iter == name_index->end()) { + return errors::NotFound("FeedInputs: unable to find feed output ", t); + } + const Node* n = iter->second; + DCHECK_EQ(n->name(), id.first); + if (id.second >= n->num_outputs()) { + return errors::InvalidArgument( + "FeedInputs: ", t, " should have output index < ", n->num_outputs()); + } + + Node* recv_node; + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second), + "_Recv") + .Attr("tensor_type", BaseType(n->output_type(id.second))) + .Attr("tensor_name", t) + .Attr("send_device", device_info.name()) + .Attr("recv_device", device_info.name()) + .Attr("send_device_incarnation", + static_cast(device_info.incarnation())) + .Attr("client_terminated", true) + .Finalize(g, &recv_node)); + recv_node->set_assigned_device_name(device_info.name()); + + // Update name_index + (*name_index)[recv_node->name()] = recv_node; + g->AddControlEdge(g->source_node(), recv_node); + + // Look through edges coming out of "n" for edges whose src_output() index + // matches "output_index". If found, replace the edges with a connection + // from the special feed node. 
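+    // NOTE: an illustrative effect, assuming "w:0" is fed and consumed by
+    // nodes x and y: a client-terminated _Recv named "_recv_w_0" is created on
+    // the client device, the data edges w:0->x and w:0->y are rewired to come
+    // from that _Recv instead, and w itself can later be pruned if nothing
+    // else reaches the sink through it.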
+ std::vector to_remove; + for (const Edge* e : n->out_edges()) { + if (e->src_output() == id.second) { + to_remove.emplace_back(e); + } else if (e->src_output() == Graph::kControlSlot && + n->def().op() == "Placeholder") { + // When feeding a Placeholder node, any outgoing control edges + // will be replaced with a control edge from the replacement + // recv_node. + // TODO(josh11b,mrry): Come up with a more elegant way of addressing + // the general version of this problem. + to_remove.emplace_back(e); + } + } + + for (const Edge* e : to_remove) { + if (e->src_output() == id.second) { + g->AddEdge(recv_node, 0, e->dst(), e->dst_input()); + } else { + CHECK_EQ(Graph::kControlSlot, e->src_output()); + g->AddControlEdge(recv_node, e->dst()); + } + g->RemoveEdge(e); + } + } + return Status::OK(); +} + +// Augment "*g" by adding special "fetch" nodes that connect to the +// tensor outputs specified in "fetch_outputs" to retrieve the output +// of the tensors. The new nodes added are set up to execute on +// "client_device_name", and are returned in "*fetch_nodes". +// +// Return true on success. On error, return false and sets *error to +// an appropriate error message (and *g is left in an indeterminate +// state). +static Status FetchOutputs(Graph* g, const DeviceAttributes& device_info, + const gtl::ArraySlice& fetch_outputs, + NameIndex* name_index, + std::vector* fetch_nodes) { + fetch_nodes->clear(); + for (const string& t : fetch_outputs) { + // Parse t into node_name and output_index. + TensorId id(ParseTensorName(t)); + + // Find node in graph with that name. + auto iter = name_index->find(id.first); + if (iter == name_index->end()) { + return errors::NotFound("FetchOutputs node ", t, ": not found"); + } + Node* n = iter->second; + DCHECK_EQ(n->name(), id.first); + VLOG(2) << "Found fetch node for " << t; + + // Validate output_index + if (id.second >= n->num_outputs()) { + return errors::InvalidArgument("FetchOutputs ", t, + ": output index too large, must be < ", + n->num_outputs()); + } + + // Create the fetch Node and connect it up + Node* send_node; + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second), + "_Send") + .Input(n, id.second) + .Attr("tensor_name", t) + .Attr("send_device", device_info.name()) + .Attr("recv_device", device_info.name()) + .Attr("send_device_incarnation", + static_cast(device_info.incarnation())) + .Attr("client_terminated", true) + .Finalize(g, &send_node)); + send_node->set_assigned_device_name(device_info.name()); + VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def()); + + // Update the index. 
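+    // Keeping "name_index" current lets the PruneForTargets pass below
+    // resolve the newly created _send node by name.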
+ (*name_index)[send_node->name()] = send_node; + + g->AddControlEdge(send_node, g->sink_node()); + fetch_nodes->push_back(send_node); + } + + return Status::OK(); +} + +static bool AddNodeToTargets(const string& node_or_tensor_name, + const NameIndex& name_index, + std::unordered_set* targets) { + TensorId id = ParseTensorName(node_or_tensor_name); + auto iter = name_index.find(id.first); + if (iter == name_index.end()) { + return false; + } + const Node* n = iter->second; + if (n->name() != node_or_tensor_name) { + return false; + } + + targets->insert(n); + return true; +} + +static Status PruneForTargets(Graph* g, const NameIndex& name_index, + const std::vector& fetch_nodes, + const gtl::ArraySlice& target_nodes) { + string not_found; + std::unordered_set targets; + for (Node* n : fetch_nodes) { + if (!AddNodeToTargets(n->name(), name_index, &targets)) { + strings::StrAppend(¬_found, n->name(), " "); + } + } + for (const string& s : target_nodes) { + if (!AddNodeToTargets(s, name_index, &targets)) { + strings::StrAppend(¬_found, s, " "); + } + } + if (!not_found.empty()) { + return errors::NotFound("PruneForTargets: Some target nodes not found: ", + not_found); + } + PruneForReverseReachability(g, targets); + + return Status::OK(); +} + +} // namespace + +namespace subgraph { + +Status RewriteGraphForExecution( + Graph* g, const gtl::ArraySlice& fed_outputs, + const gtl::ArraySlice& fetch_outputs, + const gtl::ArraySlice& target_node_names, + const DeviceAttributes& device_info) { + std::unordered_set endpoints(fed_outputs.begin(), fed_outputs.end()); + for (const auto& fetch : fetch_outputs) { + if (endpoints.count(fetch) > 0) { + return errors::InvalidArgument(fetch, " is both fed and fetched."); + } + } + + // A separate index mapping name to Node*, for use by FeedInputs, + // FetchOutputs, and PruneForTargets + NameIndex name_index; + for (Node* n : g->nodes()) { + name_index[n->name()] = n; + } + + // Add the feeds. This may replace nodes in the graph, including the nodes + // currently listed in "fetch_nodes". We pass "name_index" so the index is + // kept up to date. + if (!fed_outputs.empty()) { + TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index)); + } + + // Add the fetch nodes, also updating "name_index". + std::vector fetch_nodes; + if (!fetch_outputs.empty()) { + TF_RETURN_IF_ERROR( + FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes)); + } + + // Prune the graph to only compute what is needed for the fetch nodes and the + // targets nodes. + if (!fetch_nodes.empty() || !target_node_names.empty()) { + TF_RETURN_IF_ERROR( + PruneForTargets(g, name_index, fetch_nodes, target_node_names)); + } + + return Status::OK(); +} + +} // namespace subgraph + +} // namespace tensorflow diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h new file mode 100644 index 0000000000..d2e138e8ae --- /dev/null +++ b/tensorflow/core/graph/subgraph.h @@ -0,0 +1,49 @@ +#ifndef TENSORFLOW_GRAPH_SUBGRAPH_H_ +#define TENSORFLOW_GRAPH_SUBGRAPH_H_ + +#include + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace subgraph { + +// Rewrite the graph structure of "*g" to deal with feeding node +// outputs, fetching node outputs, and only running a subset of the +// graph. 
"fed_outputs" and "fetch_outputs" are both lists of +// output tensor identifiers in the form of +// "[:]", and "target_nodes_str" is a +// lists of of target node names in "*g" "g". +// +// In the resulting graph "*g", output edges in "fed_outputs" have +// been redirected to special "_recv" nodes introduced into the graph. +// If these fed nodes are not needed in order to compute the effects +// of the nodes in "targets_nodes" and "fetch_outputs", then these may +// be omitted from the graph. +// +// In the resulting graph "*g", additional "_send" nodes are connected +// to every output in "fetch_outputs". These "_send" nodes are set up +// to execute on the device described by device_info. +// +// On success, returns OK, and sets "*g" to a version of "*g" +// that represents the portions of the graph necessary for producing +// the output of all nodes listed in "target_node_names" and fetching the +// specific node outputs specified in "fetch_outputs". +// +// On failure, returns the error status. Possible errors include: +// - fed output "node:output_index" does not exist in "*g" +// - fetch output "node:output_index" does not exist in "*g" +// - target node "node" does not exist in "*g" +Status RewriteGraphForExecution( + Graph* g, const gtl::ArraySlice& fed_outputs, + const gtl::ArraySlice& fetch_outputs, + const gtl::ArraySlice& target_node_names, + const DeviceAttributes& device_info); + +} // namespace subgraph +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_SUBGRAPH_H_ diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc new file mode 100644 index 0000000000..ffb3e6e403 --- /dev/null +++ b/tensorflow/core/graph/subgraph_test.cc @@ -0,0 +1,305 @@ +#include "tensorflow/core/graph/subgraph.h" + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/status.h" +#include + +// TODO(josh11b): Test setting the "device" field of a NodeDef. +// TODO(josh11b): Test that feeding won't prune targets. 
+ +namespace tensorflow { +namespace { + +class SubgraphTest : public ::testing::Test { + protected: + SubgraphTest() : g_(new Graph(OpRegistry::Global())) { + RequireDefaultOps(); + device_info_.set_name("/job:a/replica:0/task:0/cpu:0"); + device_info_.set_device_type(DeviceType(DEVICE_CPU).type()); + device_info_.set_incarnation(0); + } + + ~SubgraphTest() override {} + + void ExpectOK(const string& gdef_ascii) { + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_)); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get())); + } + + Node* FindNode(const string& name) { + for (Node* n : g_->nodes()) { + if (n->name() == name) return n; + } + return nullptr; + } + + bool HasNode(const string& name) { return FindNode(name) != nullptr; } + + void ExpectNodes(const string& nodes) { + int count = 0; + std::vector actual_nodes; + for (Node* n : g_->nodes()) { + if (n->IsOp()) { + count++; + actual_nodes.push_back(n->name()); + } + } + std::sort(actual_nodes.begin(), actual_nodes.end()); + + LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " "); + + std::vector expected_nodes = str_util::Split(nodes, ','); + std::sort(expected_nodes.begin(), expected_nodes.end()); + for (const string& s : expected_nodes) { + Node* n = FindNode(s); + EXPECT_TRUE(n != nullptr) << s; + if (n->def().op() == "_Send" || n->def().op() == "_Recv") { + EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s; + } + } + + EXPECT_TRUE(actual_nodes.size() == expected_nodes.size()) + << "\nActual: " << str_util::Join(actual_nodes, ",") + << "\nExpected: " << str_util::Join(expected_nodes, ","); + } + + bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) { + for (const Edge* e : g_->edges()) { + if (e->src()->name() == src && e->src_output() == src_out && + e->dst()->name() == dst && e->dst_input() == dst_in) + return true; + } + return false; + } + bool HasControlEdge(const string& src, const string& dst) { + return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot); + } + + string Subgraph(const string& fed_str, const string& fetch_str, + const string& targets_str) { + Graph* subgraph = new Graph(OpRegistry::Global()); + CopyGraph(*g_, subgraph); + std::vector fed = + str_util::Split(fed_str, ',', str_util::SkipEmpty()); + std::vector fetch = + str_util::Split(fetch_str, ',', str_util::SkipEmpty()); + std::vector targets = + str_util::Split(targets_str, ',', str_util::SkipEmpty()); + + Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, + targets, device_info_); + if (!s.ok()) { + delete subgraph; + return s.ToString(); + } + + // Replace the graph with the subgraph for the rest of the display program + g_.reset(subgraph); + return "OK"; + } + + Graph* graph() { return g_.get(); } + + private: + GraphDef gdef_; + std::unique_ptr g_; + DeviceAttributes device_info_; +}; + +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestRelu").Input("i: float").Output("o: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); + +TEST_F(SubgraphTest, Targets1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 
't2' }"); + EXPECT_EQ("OK", Subgraph("", "", "t1")); + ExpectNodes("W1,input,t1"); +} + +TEST_F(SubgraphTest, Targets2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: 'W1' input: 'input:1' }" + "node { name: 't2' op: 'TestMul' input: 'W2' input: 't1' }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "", "t2,t3_a")); + ExpectNodes("W1,W2,input,t1,t2,t3_a"); +} + +TEST_F(SubgraphTest, FedOutputs1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("input:1", "", "t2")); + ExpectNodes("W1,W2,_recv_input_1,t1,t2"); +} + +TEST_F(SubgraphTest, FedRefNode) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }"); + EXPECT_EQ("OK", Subgraph("W1:0", "", "t1")); + ExpectNodes("_recv_W1_0,W2,t1"); + Node* n = FindNode("_recv_W1_0"); + EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); +} + +TEST_F(SubgraphTest, FedOutputs2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + // We feed input:1, but nothing connects to it, so the _recv(input:1) + // node also disappears. 
+ EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2")); + ExpectNodes("_recv_t1_0,_recv_W2_0,t2"); +} + +TEST_F(SubgraphTest, FetchOutputs1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2")); + ExpectNodes( + "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0"); +} + +TEST_F(SubgraphTest, FetchOutputs2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "t3_a", "t2")); + ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0"); +} + +TEST_F(SubgraphTest, ChainOfFools) { + ExpectOK( + "node { name: 'a' op: 'TestParams' }" + "node { name: 'b' op: 'TestRelu' input: 'a'}" + "node { name: 'c' op: 'TestRelu' input: 'b'}" + "node { name: 'd' op: 'TestRelu' input: 'c'}" + "node { name: 'e' op: 'TestRelu' input: 'd'}" + "node { name: 'f' op: 'TestRelu' input: 'e'}"); + EXPECT_EQ("OK", Subgraph("c:0", "b:0,e:0", "")); + ExpectNodes("a,b,_send_b_0,_recv_c_0,d,e,_send_e_0"); + EXPECT_TRUE(HasEdge("a", 0, "b", 0)); + EXPECT_TRUE(HasEdge("b", 0, "_send_b_0", 0)); + EXPECT_TRUE(HasEdge("_recv_c_0", 0, "d", 0)); + EXPECT_TRUE(HasEdge("d", 0, "e", 0)); + EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0)); +} + +static bool HasSubstr(const string& base, const string& substr) { + bool ok = StringPiece(base).contains(substr); + EXPECT_TRUE(ok) << base << ", expected substring " << substr; + return ok; +} + +TEST_F(SubgraphTest, Errors) { + ExpectOK( + "node { name: 'a' op: 'TestParams' }" + "node { name: 'b' op: 'TestRelu' input: 'a'}" + "node { name: 'c' op: 'TestRelu' input: 'b'}" + "node { name: 'd' op: 'TestRelu' input: 'c'}" + "node { name: 'e' op: 'TestRelu' input: 'd'}" + "node { name: 'f' op: 'TestRelu' input: 'e'}"); + // Duplicated feed and fetch + EXPECT_TRUE( + HasSubstr(Subgraph("c:0", "b:0,c:0", ""), "both fed and fetched")); + // Feed not found. + EXPECT_TRUE(HasSubstr(Subgraph("foo:0", "", ""), "unable to find")); + // Fetch not found. + EXPECT_TRUE(HasSubstr(Subgraph("", "foo:0", ""), "not found")); + // Target not found. + EXPECT_TRUE(HasSubstr(Subgraph("", "", "foo"), "not found")); +} + +REGISTER_OP("In").Output("o: float"); +REGISTER_OP("Op").Input("i: float").Output("o: float"); + +static void BM_Subgraph(int iters, int num_nodes) { + DeviceAttributes device_info; + device_info.set_name("/job:a/replica:0/task:0/cpu:0"); + device_info.set_device_type(DeviceType(DEVICE_CPU).type()); + device_info.set_incarnation(0); + + testing::StopTiming(); + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* last_node = nullptr; + for (int i = 0; i < num_nodes; i++) { + string name = strings::StrCat("N", i); + if (i > 0) { + last_node = ops::UnaryOp("Op", last_node, b.opts().WithName(name)); + } else { + last_node = ops::SourceOp("In", b.opts().WithName(name)); + } + } + TF_CHECK_OK(b.ToGraph(&g)); + } + + std::vector fed; + if (num_nodes > 1000) { + fed.push_back(strings::StrCat("N", num_nodes - 1000)); + } + std::vector fetch; + std::vector targets = {strings::StrCat("N", num_nodes - 1)}; + testing::StartTiming(); + while (--iters > 0) { + Graph* subgraph = new Graph(OpRegistry::Global()); + CopyGraph(g, subgraph); + TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch, + targets, device_info)); + delete subgraph; + } +} +BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc new file mode 100644 index 0000000000..f789110ff3 --- /dev/null +++ b/tensorflow/core/graph/tensor_id.cc @@ -0,0 +1,41 @@ +#include "tensorflow/core/graph/tensor_id.h" + +#include + +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +TensorId ParseTensorName(const string& name) { + return ParseTensorName(StringPiece(name.data(), name.size())); +} + +TensorId ParseTensorName(StringPiece name) { + // Parse either a name, or a name:digits. To do so, we go backwards + // from the end of the string, skipping over a run of digits. If + // we hit a ':' character, then we know we are in the 'name:digits' + // regime. Otherwise, the output index is implicitly 0, and the whole + // name string forms the first part of the tensor name. + // + // Equivalent to matching with this regexp: ([^:]+):(\\d+) + const char* base = name.data(); + const char* p = base + name.size() - 1; + int index = 0; + int mul = 1; + while (p > base && (*p >= '0' && *p <= '9')) { + index += ((*p - '0') * mul); + mul *= 10; + p--; + } + TensorId id; + if (p > base && *p == ':' && mul > 1) { + id.first = StringPiece(base, p - base); + id.second = index; + } else { + id.first = name; + id.second = 0; + } + return id; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h new file mode 100644 index 0000000000..f1f3846875 --- /dev/null +++ b/tensorflow/core/graph/tensor_id.h @@ -0,0 +1,28 @@ +#ifndef TENSORFLOW_GRAPH_TENSOR_ID_H_ +#define TENSORFLOW_GRAPH_TENSOR_ID_H_ + +#include + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +// Identifier for a tensor within a step. +// first == operation_name, second == output_index +// Note: does not own backing storage for name. +struct TensorId : public std::pair { + typedef std::pair Base; + + // Inherit the set of constructors. 
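+  // In particular, a TensorId can be constructed directly from a name and
+  // an output index, e.g. TensorId("W1", 0).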
+ using Base::pair; + + string ToString() const { return strings::StrCat(first, ":", second); } +}; + +TensorId ParseTensorName(const string& name); +TensorId ParseTensorName(StringPiece name); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_TENSOR_ID_H_ diff --git a/tensorflow/core/graph/tensor_id_test.cc b/tensorflow/core/graph/tensor_id_test.cc new file mode 100644 index 0000000000..b945774cc3 --- /dev/null +++ b/tensorflow/core/graph/tensor_id_test.cc @@ -0,0 +1,77 @@ +#include "tensorflow/core/graph/tensor_id.h" +#include +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +static string ParseHelper(const string& n) { + TensorId id = ParseTensorName(n); + return strings::StrCat(id.first, ":", id.second); +} + +TEST(TensorIdTest, ParseTensorName) { + EXPECT_EQ(ParseHelper("W1"), "W1:0"); + EXPECT_EQ(ParseHelper("weights:0"), "weights:0"); + EXPECT_EQ(ParseHelper("W1:1"), "W1:1"); + EXPECT_EQ(ParseHelper("W1:17"), "W1:17"); + EXPECT_EQ(ParseHelper("xyz1_17"), "xyz1_17:0"); +} + +static uint32 Skewed(random::SimplePhilox* rnd, int max_log) { + const uint32 space = 1 << (rnd->Rand32() % (max_log + 1)); + return rnd->Rand32() % space; +} + +static void BM_ParseTensorName(int iters, int arg) { + testing::StopTiming(); + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + std::vector names; + for (int i = 0; i < 100; i++) { + string name; + switch (arg) { + case 0: { // Generate random names + size_t len = Skewed(&rnd, 4); + while (name.size() < len) { + name += rnd.OneIn(4) ? '0' : 'a'; + } + if (rnd.OneIn(3)) { + strings::StrAppend(&name, ":", rnd.Uniform(12)); + } + break; + } + case 1: + name = "W1"; + break; + case 2: + name = "t0003"; + break; + case 3: + name = "weights"; + break; + case 4: + name = "weights:17"; + break; + default: + LOG(FATAL) << "Unexpected arg"; + break; + } + names.push_back(name); + } + testing::StartTiming(); + TensorId id; + int index = 0; + int sum = 0; + while (--iters > 0) { + id = ParseTensorName(names[index++ % names.size()]); + sum += id.second; + } + VLOG(2) << sum; // Prevent compiler from eliminating loop body +} +BENCHMARK(BM_ParseTensorName)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc new file mode 100644 index 0000000000..e49d5e819a --- /dev/null +++ b/tensorflow/core/graph/testlib.cc @@ -0,0 +1,299 @@ +#include "tensorflow/core/graph/testlib.h" + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace test { +namespace graph { + +Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, + const uint64 sender_incarnation, const string& receiver) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send") + .Input(input, 0) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast(sender_incarnation)) + .Attr("recv_device", receiver) + .Finalize(g, &ret)); + return ret; +} + +Node* Recv(Graph* g, const string& 
tensor, const string& type, + const string& sender, const uint64 sender_incarnation, + const string& receiver) { + Node* ret; + DataType dtype; + CHECK(DataTypeFromString(type, &dtype)); + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv") + .Attr("tensor_type", dtype) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast(sender_incarnation)) + .Attr("recv_device", receiver) + .Finalize(g, &ret)); + return ret; +} + +Node* Constant(Graph* g, const Tensor& tensor) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const") + .Attr("dtype", tensor.dtype()) + .Attr("value", tensor) + .Finalize(g, &ret)); + return ret; +} + +Node* Constant(Graph* g, const Tensor& tensor, const string& name) { + Node* ret; + TF_CHECK_OK(NodeBuilder(name, "Const") + .Attr("dtype", tensor.dtype()) + .Attr("value", tensor) + .Finalize(g, &ret)); + return ret; +} + +Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable") + .Attr("dtype", dtype) + .Attr("shape", shape) + .Finalize(g, &ret)); + return ret; +} + +Node* Assign(Graph* g, Node* var, Node* val) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign") + .Input(var) + .Input(val) + .Attr("use_locking", true) + .Finalize(g, &ret)); + return ret; +} + +Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, + bool keep_dims) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce) + .Input(data) + .Input(axes) + .Attr("keep_dims", keep_dims) + .Finalize(g, &ret)); + return ret; +} + +Node* QuantizeToUINT8(Graph* g, Node* data) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize") + .Input(data) + .Attr("T", DT_QUINT8) + .Attr("max_range", 1.0f) + .Attr("min_range", -1.0f) + .Finalize(g, &ret)); + return ret; +} + +Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, + bool transpose_b) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul") + .Input(in0) + .Input(in1) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b) + .Finalize(g, &ret)); + return ret; +} + +Node* RandomNumberGenerator(const string& op, Graph* g, Node* input, + DataType dtype) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), op) + .Input(input) + .Attr("dtype", dtype) + .Attr("seed", 0) + .Finalize(g, &ret)); + return ret; +} + +Node* RandomUniform(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomUniform", g, input, dtype); +} + +Node* RandomGaussian(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomStandardNormal", g, input, dtype); +} + +Node* RandomParameters(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomParameters", g, input, dtype); +} + +Node* Unary(Graph* g, const string& func, Node* input, int index) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret)); + return ret; +} + +Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), func) + .Input(in0) + .Input(in1) + .Finalize(g, &ret)); + return ret; +} + +Node* Multi(Graph* g, const string& func, gtl::ArraySlice ins) { + Node* ret; + auto b = NodeBuilder(g->NewName("n"), func); + for (Node* n : ins) b = b.Input(n); + TF_CHECK_OK(b.Finalize(g, &ret)); + return ret; +} + +Node* Identity(Graph* g, Node* input, int index) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), 
"Identity") + .Input(input, index) + .Finalize(g, &ret)); + return ret; +} + +Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); } + +Node* Error(Graph* g, Node* input, const string& errmsg) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error") + .Input(input) + .Attr("message", errmsg) + .Finalize(g, &ret)); + return ret; +} + +Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) { + DCHECK(out_type != invalid_type); + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType") + .Attr("TIn", out_type) + .Attr("TOut", invalid_type) + .Finalize(g, &ret)); + return ret; +} + +Node* Delay(Graph* g, Node* input, Microseconds delay_micros) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay") + .Input(input) + .Attr("micros", delay_micros.value()) + .Finalize(g, &ret)); + return ret; +} + +Node* NoOp(Graph* g, const std::vector& control_inputs) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp") + .ControlInputs(control_inputs) + .Finalize(g, &ret)); + return ret; +} + +Node* Switch(Graph* g, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch") + .Input(in0) + .Input(in1) + .Finalize(g, &ret)); + return ret; +} + +Node* Enter(Graph* g, Node* input, const string& frame_name) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter") + .Input(input) + .Attr("frame_name", frame_name) + .Finalize(g, &ret)); + return ret; +} + +Node* Exit(Graph* g, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* Merge(Graph* g, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge") + .Input({in0, in1}) + .Finalize(g, &ret)); + return ret; +} + +Node* Merge(Graph* g, Node* in0, gtl::ArraySlice remaining_in) { + std::vector inputs; + inputs.reserve(remaining_in.size() + 1); + inputs.emplace_back(in0); + for (const string& in_name : remaining_in) { + inputs.emplace_back(in_name, 0, inputs[0].dt); + } + + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret)); + return ret; +} + +Node* Next(Graph* g, const string& name, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* LoopCond(Graph* g, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* Less(Graph* g, Node* in0, Node* in1) { + return Binary(g, "Less", in0, in1); +} + +Node* Select(Graph* g, Node* c, Node* inx, Node* iny) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select") + .Input(c) + .Input(inx) + .Input(iny) + .Finalize(g, &ret)); + return ret; +} + +Node* Cast(Graph* g, Node* in, DataType dst) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast") + .Input(in) + .Attr("DstT", dst) + .Finalize(g, &ret)); + return ret; +} + +void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } + +} // end namespace graph +} // end namespace test +} // end namespace tensorflow diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h new file mode 100644 index 0000000000..11905bbf6a --- /dev/null +++ b/tensorflow/core/graph/testlib.h @@ -0,0 +1,141 @@ +// DEPRECATED: Use GraphDefBuilder instead. 
+ +#ifndef TENSORFLOW_GRAPH_TESTLIB_H_ +#define TENSORFLOW_GRAPH_TESTLIB_H_ + +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { +namespace test { +namespace graph { + +// Converts "g" into its corresponding GraphDef "def". +// DEPRECATED: call g->ToGraphDef(def) instead. +void ToGraphDef(Graph* g, GraphDef* def); + +// A few helpers to construct a graph. + +// Adds a node in "g" producing a constant "tensor". +Node* Constant(Graph* g, const Tensor& tensor); +Node* Constant(Graph* g, const Tensor& tensor, const string& name); + +// Adds a variable in "g" of the given "shape" and "dtype". +Node* Var(Graph* g, const DataType dtype, const TensorShape& shape); + +// Adds an assign node in "g" which assigns "val" into "var". +Node* Assign(Graph* g, Node* var, Node* val); + +// Adds a send node "g" sending "input" as a named "tensor" from +// "sender" to "receiver". +Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, + const uint64 sender_incarnation, const string& receiver); + +// Adds a recv node in "g" receiving a named "tensor" from "sender" +// to "receiver". +Node* Recv(Graph* g, const string& tensor, const string& type, + const string& sender, const uint64 sender_incarnation, + const string& receiver); + +// Adds a reduction "node" in "g" doing sum(data, axes). "reduce" is +// a reduction, e.g., Sum, Max, Min, Mean, etc. +Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, + bool keep_dims = false); + +// Adds a Matmul node in g doing in0.contract(in1). +Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, + bool transpose_b); + +// Adds a Quantize node into g that quantize floats into QUINT8. The range of +// the input float tensor is assumed to be [-1, 1]. +Node* QuantizeToUINT8(Graph* g, Node* data); + +// Adds a unary function "func" "node" in "g" taking "input". +Node* Unary(Graph* g, const string& func, Node* input, int index = 0); + +// Adds an identity node in "g" taking "input" and producing an +// identity copy. +Node* Identity(Graph* g, Node* input, int index = 0); + +// Adds a binary function "func" node in "g" taking "in0" and "in1". +// Requires that "func" name an attr-style Op. +Node* Binary(Graph* g, const string& func, Node* in0, Node* in1); + +// Adds a function "func" node in "g" taking inputs "ins". +// Requires that "func" name an attr-style Op. +Node* Multi(Graph* g, const string& func, gtl::ArraySlice ins); + +// Adds a binary add node in "g" doing in0 + in1. +Node* Add(Graph* g, Node* in0, Node* in1); + +// Generates random unit uniform distribution of the input shape. +Node* RandomUniform(Graph* g, Node* input, DataType dtype); + +// Generates random unit normal distribution of the input shape. +Node* RandomGaussian(Graph* g, Node* input, DataType dtype); + +// Generates random parameters from the truncated standard normal distribution +// of the nput shape +Node* RandomParameters(Graph* g, Node* input, DataType dtype); + +// Adds an error node in "g". The node's computation always +// generates an error with the given error message "errmsg". +Node* Error(Graph* g, Node* input, const string& errmsg); + +// Adds a node that generates a invalid ref output. +Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type); + +// Adds a node in "g". 
Its Compute() sleeps a while and outputs the +// input (i.e., same as identity). +Node* Delay(Graph* g, Node* input, Microseconds delay_micros); + +// Adds a no-op "node" in "g", with control inputs from all nodes in +// control_inputs vector. +Node* NoOp(Graph* g, const std::vector& control_inputs); + +// Adds a Switch node in "g". If "in1" is true, it forwards "in0" to +// output 1. Otherwise, it forwards "in0" to output 0. +Node* Switch(Graph* g, Node* in0, Node* in1); + +// Adds an Enter node in "g", which enters a new frame. +Node* Enter(Graph* g, Node* input, const string& frame_name); + +// Adds an Exit node in "g", which exits a frame. +Node* Exit(Graph* g, Node* input); + +// Adds a Merge node in "g" with two inputs "in0" and "in1". +Node* Merge(Graph* g, Node* in0, Node* in1); + +// Adds a Merge node in "g". The first input is "in0", the remaining +// inputs are only given by their names in remaining_in. +Node* Merge(Graph* g, Node* in0, gtl::ArraySlice remaining_in); + +// Adds a NextIteration node in "g", which makes its input available +// to the next iteration. +Node* Next(Graph* g, const string& name, Node* input); + +// Adds a LoopCond node in "g", representing the "pivot" termination +// condition of a loop. +Node* LoopCond(Graph* g, Node* input); + +// Adds a less node in "g", which returns true iff "in0" < "in1". +Node* Less(Graph* g, Node* in0, Node* in1); + +// Adds a select node in "g", which outputs either "inx" or "iny" +// depending on the boolean value of "c". +Node* Select(Graph* g, Node* c, Node* inx, Node* iny); + +// Casts "in" into data type "dst". +Node* Cast(Graph* g, Node* in, DataType dst); + +} // end namespace graph +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPH_TESTLIB_H_ diff --git a/tensorflow/core/graph/types.h b/tensorflow/core/graph/types.h new file mode 100644 index 0000000000..41400611a9 --- /dev/null +++ b/tensorflow/core/graph/types.h @@ -0,0 +1,17 @@ +#ifndef TENSORFLOW_GRAPH_TYPES_H_ +#define TENSORFLOW_GRAPH_TYPES_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/int_type.h" + +namespace tensorflow { + +// We model running time in microseconds. +TF_LIB_GTL_DEFINE_INT_TYPE(Microseconds, int64); + +// We model size in bytes. 
+TF_LIB_GTL_DEFINE_INT_TYPE(Bytes, int64); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_TYPES_H_ diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc new file mode 100644 index 0000000000..7cc0534354 --- /dev/null +++ b/tensorflow/core/kernels/adjust_contrast_op.cc @@ -0,0 +1,121 @@ +// See docs in ../ops/image_ops.cc +#define EIGEN_USE_THREADS + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/adjust_contrast_op.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class AdjustContrastOp : public OpKernel { + public: + explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) { + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& factor = context->input(1); + const Tensor& min_value = context->input(2); + const Tensor& max_value = context->input(3); + OP_REQUIRES(context, input.dims() >= 3, + errors::InvalidArgument("input must be at least 3-D, got shape", + input.shape().ShortDebugString())); + const int64 height = input.dim_size(input.dims() - 3); + const int64 width = input.dim_size(input.dims() - 2); + const int64 channels = input.dim_size(input.dims() - 1); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor.shape()), + errors::InvalidArgument("contrast_factor must be scalar: ", + factor.shape().ShortDebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_value.shape()), + errors::InvalidArgument("min_value must be scalar: ", + min_value.shape().ShortDebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_value.shape()), + errors::InvalidArgument("max_value must be scalar: ", + max_value.shape().ShortDebugString())); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + + Tensor mean_values; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + TensorShape(input.shape()), + &mean_values)); + + if (input.NumElements() > 0) { + const int64 batch = input.NumElements() / (height * width * channels); + const int64 shape[4] = {batch, height, width, channels}; + functor::AdjustContrast()( + context->eigen_device(), input.shaped(shape), + factor.scalar(), min_value.scalar(), + max_value.scalar(), mean_values.shaped(shape), + output->shaped(shape)); + } + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("AdjustContrast").Device(DEVICE_CPU).TypeConstraint("T"), \ + AdjustContrastOp); + +REGISTER_KERNEL(uint8); +REGISTER_KERNEL(int8); +REGISTER_KERNEL(int16); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +// Forward declarations of the function specializations for GPU (to prevent +// building the GPU versions here, they will be built compiling _gpu.cu.cc). 
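[Editor's note] The functor invoked by the kernel above (defined in adjust_contrast_op.h later in this patch) reduces to simple per-element arithmetic. As a hedged scalar sketch of that math, where "mean" is the average of input(b, :, :, c) over the height and width dimensions (AdjustContrastScalar is an illustrative helper, not part of the patch):

#include <algorithm>

// Scalar sketch, not the kernel: one output element of AdjustContrast.
float AdjustContrastScalar(float x, float mean, float contrast_factor,
                           float min_value, float max_value) {
  const float adjusted = (x - mean) * contrast_factor + mean;
  return std::min(max_value, std::max(min_value, adjusted));
}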
+namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void AdjustContrast::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + typename TTypes::ConstScalar contrast_factor, \ + typename TTypes::ConstScalar min_value, \ + typename TTypes::ConstScalar max_value, \ + typename TTypes::Tensor mean_values, \ + typename TTypes::Tensor output); \ + extern template struct AdjustContrast; + +DECLARE_GPU_SPEC(uint8); +DECLARE_GPU_SPEC(int8); +DECLARE_GPU_SPEC(int16); +DECLARE_GPU_SPEC(int32); +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("AdjustContrast").Device(DEVICE_GPU).TypeConstraint("T"), \ + AdjustContrastOp); +REGISTER_GPU_KERNEL(uint8); +REGISTER_GPU_KERNEL(int8); +REGISTER_GPU_KERNEL(int16); +REGISTER_GPU_KERNEL(int32); +REGISTER_GPU_KERNEL(float); +REGISTER_GPU_KERNEL(double); +#undef REGISTER_GPU_KERNEL + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/adjust_contrast_op.h b/tensorflow/core/kernels/adjust_contrast_op.h new file mode 100644 index 0000000000..2182b33c03 --- /dev/null +++ b/tensorflow/core/kernels/adjust_contrast_op.h @@ -0,0 +1,64 @@ +#ifndef TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_ +#define TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_ +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by AdjustContrastOp to do the computations. +template +struct AdjustContrast { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstScalar contrast_factor, + typename TTypes::ConstScalar min_value, + typename TTypes::ConstScalar max_value, + typename TTypes::Tensor mean_values, + typename TTypes::Tensor output) { + const int batch = input.dimension(0); + const int height = input.dimension(1); + const int width = input.dimension(2); + const int channels = input.dimension(3); + + Eigen::array scalar_broadcast{{batch, height, width, channels}}; +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::array reduction_axis{{1, 2}}; + Eigen::array scalar{{1, 1, 1, 1}}; + Eigen::array broadcast_dims{{1, height, width, 1}}; + Eigen::Tensor::Dimensions reshape_dims{{batch, 1, 1, channels}}; +#else + Eigen::IndexList, Eigen::type2index<2> > + reduction_axis; + Eigen::IndexList, Eigen::type2index<1>, + Eigen::type2index<1>, Eigen::type2index<1> > scalar; + Eigen::IndexList, int, int, Eigen::type2index<1> > + broadcast_dims; + broadcast_dims.set(1, height); + broadcast_dims.set(2, width); + Eigen::IndexList, Eigen::type2index<1>, int> + reshape_dims; + reshape_dims.set(0, batch); + reshape_dims.set(3, channels); +#endif + mean_values.device(d) = input.template cast() + .mean(reduction_axis) + .eval() + .reshape(reshape_dims) + .broadcast(broadcast_dims); + + auto contrast_factor_tensor = + contrast_factor.reshape(scalar).broadcast(scalar_broadcast); + auto adjusted = + (input.template cast() - mean_values) * contrast_factor_tensor + + mean_values; + auto min_bcast = min_value.reshape(scalar).broadcast(scalar_broadcast); + auto max_bcast = max_value.reshape(scalar).broadcast(scalar_broadcast); + // TODO(wicke): This is rather slow and should be re-written as pure cuda. 
+ output.device(d) = adjusted.cwiseMin(max_bcast).cwiseMax(min_bcast); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_ diff --git a/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc b/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc new file mode 100644 index 0000000000..75b177cf4d --- /dev/null +++ b/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc @@ -0,0 +1,43 @@ +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +static Graph* BM_AdjustContrast(int batches, int width, int height) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor in(DT_UINT8, TensorShape({batches, width, height, 3})); + in.flat().setRandom(); + Tensor factor(DT_FLOAT, TensorShape({})); + factor.flat().setConstant(1.2); + Tensor min_value(DT_FLOAT, TensorShape({})); + min_value.flat().setConstant(7.); + Tensor max_value(DT_FLOAT, TensorShape({})); + max_value.flat().setConstant(250.); + + Node* ret; + NodeBuilder(g->NewName("n"), "AdjustContrast") + .Input(test::graph::Constant(g, in)) + .Input(test::graph::Constant(g, factor)) + .Input(test::graph::Constant(g, min_value)) + .Input(test::graph::Constant(g, max_value)) + .Finalize(g, &ret); + return g; +} + +#define BM_AdjustContrastDev(DEVICE, B, W, H) \ + static void BM_AdjustContrast_##DEVICE##_##B##_##W##_##H(int iters) { \ + testing::ItemsProcessed(iters* B* W* H * 3); \ + test::Benchmark(#DEVICE, BM_AdjustContrast(B, W, H)).Run(iters); \ + } \ + BENCHMARK(BM_AdjustContrast_##DEVICE##_##B##_##W##_##H); + +// Benchmark results as of cl/106323955 +// BM_AdjustContrast_cpu_1_299_299 3416770 22008951 100 11.6M items/s + +// BM_AdjustContrast_gpu_32_299_299 37117844 45512374 100 179.8M items/s +BM_AdjustContrastDev(cpu, 1, 299, 299) BM_AdjustContrastDev(gpu, 32, 299, 299) + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/adjust_contrast_op_gpu.cu.cc b/tensorflow/core/kernels/adjust_contrast_op_gpu.cu.cc new file mode 100644 index 0000000000..7a9b0726fd --- /dev/null +++ b/tensorflow/core/kernels/adjust_contrast_op_gpu.cu.cc @@ -0,0 +1,22 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/adjust_contrast_op.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +template struct functor::AdjustContrast; +template struct functor::AdjustContrast; +template struct functor::AdjustContrast; +template struct functor::AdjustContrast; +template struct functor::AdjustContrast; +template struct functor::AdjustContrast; +template struct functor::AdjustContrast; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/adjust_contrast_op_test.cc b/tensorflow/core/kernels/adjust_contrast_op_test.cc new file mode 100644 index 0000000000..67891e4fa1 --- /dev/null +++ b/tensorflow/core/kernels/adjust_contrast_op_test.cc @@ -0,0 +1,88 @@ +#include "tensorflow/core/framework/allocator.h" +#include +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include 
"tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +class AdjustContrastOpTest : public OpsTestBase { + protected: + void MakeOp() { RequireDefaultOps(); } +}; + +TEST_F(AdjustContrastOpTest, Simple_1113) { + RequireDefaultOps(); + EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DT_FLOAT) + .Finalize(node_def())); + EXPECT_OK(InitOp()); + AddInputFromArray(TensorShape({1, 1, 1, 3}), {-1, 2, 3}); + AddInputFromArray(TensorShape({}), {1.0}); + AddInputFromArray(TensorShape({}), {0.0}); + AddInputFromArray(TensorShape({}), {2.0}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 3})); + test::FillValues(&expected, {0, 2, 2}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AdjustContrastOpTest, Simple_1223) { + RequireDefaultOps(); + EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DT_FLOAT) + .Finalize(node_def())); + EXPECT_OK(InitOp()); + AddInputFromArray(TensorShape({1, 2, 2, 3}), + {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); + AddInputFromArray(TensorShape({}), {0.2}); + AddInputFromArray(TensorShape({}), {0.0}); + AddInputFromArray(TensorShape({}), {10.0}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 2, 3})); + test::FillValues( + &expected, {2.2, 6.2, 10, 2.4, 6.4, 10, 2.6, 6.6, 10, 2.8, 6.8, 10}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AdjustContrastOpTest, Big_99x99x3) { + EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DT_FLOAT) + .Finalize(node_def())); + EXPECT_OK(InitOp()); + + std::vector values; + for (int i = 0; i < 99 * 99 * 3; ++i) { + values.push_back(i % 255); + } + + AddInputFromArray(TensorShape({1, 99, 99, 3}), values); + AddInputFromArray(TensorShape({}), {0.2}); + AddInputFromArray(TensorShape({}), {0}); + AddInputFromArray(TensorShape({}), {255}); + ASSERT_OK(RunOpKernel()); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc new file mode 100644 index 0000000000..426e868735 --- /dev/null +++ b/tensorflow/core/kernels/aggregate_ops.cc @@ -0,0 +1,238 @@ +// See docs in ../ops/math_ops.cc. 
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/aggregate_ops.h" + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/register_types.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class AddNOp : public OpKernel { + public: + explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + if (!ctx->ValidateInputsAreSameShape(this)) return; + + const Tensor& input0 = ctx->input(0); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output)); + auto To = output->flat(); + + const int num = ctx->num_inputs(); + if (num == 1) { + *output = input0; + return; + } + +#define I(IDX) ctx->input(IDX).flat() + +#if defined(PLATFORM_POSIX_ANDROID) || defined(PLATFORM_GOOGLE_ANDROID) + // On Android, we only support additions of two arguments, so we + // can reduce the number of template instantiations. + OP_REQUIRES(ctx, num == 2, + errors::InvalidArgument("Only additions of two arguments " + "supported. Num inputs: ", + num)); + functor::Add2Functor functor2; + functor2(ctx->template eigen_device(), To, I(0), I(1)); +#else + static const int kWidth = 8; + int r = num % kWidth; + + switch (r) { + case 2: { + functor::Add2Functor functor2; + functor2(ctx->template eigen_device(), To, I(0), I(1)); + break; + } + case 3: { + functor::Add3Functor functor3; + functor3(ctx->template eigen_device(), To, I(0), I(1), I(2)); + break; + } + case 4: { + functor::Add4Functor functor4; + functor4(ctx->template eigen_device(), To, I(0), I(1), I(2), + I(3)); + break; + } + case 5: { + functor::Add5Functor functor5; + functor5(ctx->template eigen_device(), To, I(0), I(1), I(2), + I(3), I(4)); + break; + } + case 6: { + functor::Add6Functor functor6; + functor6(ctx->template eigen_device(), To, I(0), I(1), I(2), + I(3), I(4), I(5)); + break; + } + case 7: { + functor::Add7Functor functor7; + functor7(ctx->template eigen_device(), To, I(0), I(1), I(2), + I(3), I(4), I(5), I(6)); + break; + } + case 0: { + functor::Add8Functor functor8; + functor8(ctx->template eigen_device(), To, I(0), I(1), I(2), + I(3), I(4), I(5), I(6), I(7)); + r = 8; + break; + } + case 1: { + functor::Add9Functor functor9; + functor9(ctx->template eigen_device(), To, I(0), I(1), I(2), + I(3), I(4), I(5), I(6), I(7), I(8)); + r = 9; + break; + } + } + + for (; r < num; r += kWidth) { + functor::Add8pFunctor functor8p; + functor8p(ctx->template eigen_device(), To, I(r), I(r + 1), + I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7)); + } +#endif // defined(PLATFORM_POSIX_ANDROID) || defined(PLATFORM_GOOGLE_ANDROID) + +#undef I + } +}; + +// Partial specializations for a CPUDevice, that uses the Eigen implementation +// from AddNEigenImpl. 
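[Editor's note] To make the unrolled dispatch in AddNOp::Compute above easier to follow: the first chunk of inputs is handled by an assigning AddK functor (K between 2 and 9, with remainders 0 and 1 mapped to 8 and 9), and every subsequent chunk of eight inputs is folded in with the accumulating Add8p functor. AddNChunkSizes below is an illustrative helper sketch, not part of the patch:

#include <vector>

// Sketch: sizes of the input chunks AddNOp::Compute hands to the functors.
std::vector<int> AddNChunkSizes(int num_inputs) {  // assumes num_inputs >= 2
  const int kWidth = 8;
  int first = num_inputs % kWidth;
  if (first == 0) first = 8;  // the switch's "case 0" uses Add8Functor
  if (first == 1) first = 9;  // the switch's "case 1" uses Add9Functor
  std::vector<int> chunks = {first};
  for (int r = first; r < num_inputs; r += kWidth) chunks.push_back(kWidth);
  return chunks;  // e.g. 13 -> {5, 8}; 17 -> {9, 8}; 16 -> {8, 8}
}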
+namespace functor { +template +struct Add2Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2) { + Add2EigenImpl::Compute(d, out, in1, in2); + } +}; +template +struct Add3Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3) { + Add3EigenImpl::Compute(d, out, in1, in2, in3); + } +}; +template +struct Add4Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4) { + Add4EigenImpl::Compute(d, out, in1, in2, in3, in4); + } +}; +template +struct Add5Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5) { + Add5EigenImpl::Compute(d, out, in1, in2, in3, in4, in5); + } +}; +template +struct Add6Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6) { + Add6EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6); + } +}; +template +struct Add7Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7) { + Add7EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7); + } +}; + +template +struct Add8Functor { + void operator()( + const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + Add8EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template +struct Add8pFunctor { + void operator()( + const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + Add8pEigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template +struct Add9Functor { + void operator()( + const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8, + typename TTypes::ConstFlat in9) { + Add9EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8, in9); + } +}; + +} // namespace functor + +#define REGISTER_ADDN(type, dev) \ + REGISTER_KERNEL_BUILDER( \ + Name("AddN").Device(DEVICE_##dev).TypeConstraint("T"), \ + AddNOp) + +#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU) + 
+TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU); +#undef REGISTER_ADDN_CPU + +#if GOOGLE_CUDA +REGISTER_ADDN(float, GPU); +#endif // GOOGLE_CUDA + +#undef REGISTER_ADDN + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/aggregate_ops.h b/tensorflow/core/kernels/aggregate_ops.h new file mode 100644 index 0000000000..2214901970 --- /dev/null +++ b/tensorflow/core/kernels/aggregate_ops.h @@ -0,0 +1,211 @@ +#ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_H_ +#define TENSORFLOW_KERNELS_AGGREGATE_OPS_H_ + +// Functor definitions for Aggregate ops, must be compilable by nvcc. + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +template +struct Add2Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2); +}; + +template +struct Add2EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2) { + out.device(d) = in1 + in2; + } +}; + +template +struct Add3Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3); +}; + +template +struct Add3EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3) { + out.device(d) = in1 + in2 + in3; + } +}; + +template +struct Add4Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4); +}; + +template +struct Add4EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4) { + out.device(d) = in1 + in2 + in3 + in4; + } +}; + +template +struct Add5Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5); +}; + +template +struct Add5EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5) { + out.device(d) = in1 + in2 + in3 + in4 + in5; + } +}; + +template +struct Add6Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6); +}; + +template +struct Add6EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6) { + out.device(d) = in1 + in2 + in3 + in4 + in5 + in6; + } +}; + +template +struct Add7Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename 
TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7); +}; + +template +struct Add7EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7) { + out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7; + } +}; + +template +struct Add8Functor { + void operator()( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8); +}; + +template +struct Add8EigenImpl { + static void Compute( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8; + } +}; + +// Add8p is like Add8 except the underlying implementation should += +// rather than assign to the output. +template +struct Add8pFunctor { + void operator()( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8); +}; + +template +struct Add8pEigenImpl { + static void Compute( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8; + } +}; + +template +struct Add9Functor { + void operator()( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8, + typename TTypes::ConstFlat in9); +}; + +template +struct Add9EigenImpl { + static void Compute( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8, + typename TTypes::ConstFlat in9) { + out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_AGGREGATE_OPS_H_ diff --git a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc new file mode 100644 index 0000000000..5cf2934ac1 --- /dev/null +++ b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc @@ -0,0 +1,141 @@ +#if 
GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/aggregate_ops.h" + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Partial specialization for a GPUDevice, that uses the Eigen implementation. +namespace functor { +template +struct Add2Functor { + void operator()(const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2) { + Add2EigenImpl::Compute(d, out, in1, in2); + } +}; + +template +struct Add3Functor { + void operator()(const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3) { + Add3EigenImpl::Compute(d, out, in1, in2, in3); + } +}; + +template +struct Add4Functor { + void operator()(const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4) { + Add4EigenImpl::Compute(d, out, in1, in2, in3, in4); + } +}; + +template +struct Add5Functor { + void operator()(const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5) { + Add5EigenImpl::Compute(d, out, in1, in2, in3, in4, in5); + } +}; + +template +struct Add6Functor { + void operator()(const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6) { + Add6EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6); + } +}; + +template +struct Add7Functor { + void operator()(const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7) { + Add7EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7); + } +}; + +template +struct Add8Functor { + void operator()( + const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + Add8EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template +struct Add8pFunctor { + void operator()( + const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + Add8pEigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template +struct Add9Functor { + void operator()( + const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8, + typename 
TTypes::ConstFlat in9) { + Add9EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8, in9); + } +}; + +} // end namespace functor + +// Instantiate the GPU implementation for float. +template struct functor::Add2Functor; +template struct functor::Add3Functor; +template struct functor::Add4Functor; +template struct functor::Add5Functor; +template struct functor::Add6Functor; +template struct functor::Add7Functor; +template struct functor::Add8Functor; +template struct functor::Add8pFunctor; +template struct functor::Add9Functor; + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc new file mode 100644 index 0000000000..0845eebf09 --- /dev/null +++ b/tensorflow/core/kernels/argmax_op.cc @@ -0,0 +1,163 @@ +// See docs in ../ops/math_ops.cc. + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/core/kernels/argmax_op.h" + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class ArgOp : public OpKernel { + public: + explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& dimension = context->input(1); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()), + errors::InvalidArgument( + "dim must be a scalar, but received tensor of shape: ", + dimension.shape().DebugString())); + + const int32 dim = dimension.scalar()(); + const int input_dims = input.dims(); + + OP_REQUIRES(context, dim >= 0, errors::InvalidArgument("dim must be >= 0")); + OP_REQUIRES(context, dim < input_dims, + errors::InvalidArgument("Minimum tensor rank: ", dim, + " but got: ", input_dims)); + + TensorShape output_shape; + TensorShape input_shape = input.shape(); + for (int d = 0; d < input_dims - 1; ++d) { + output_shape.AddDim(input_shape.dim_size((d < dim) ? 
d : d + 1)); + } + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + +#define HANDLE_DIM(NDIM) \ + case NDIM: \ + ArgFunctor::Reduce##NDIM(context->eigen_device(), \ + input.tensor(), dim, \ + output->tensor()); \ + break; + + switch (input_dims) { + HANDLE_DIM(1); + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + + default: + OP_REQUIRES(context, false, + errors::InvalidArgument( + "ArgOp : Unhandled input dimensions: ", input_dims)); + } + } +#undef HANDLE_DIM + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); +}; + +template +class ArgMaxOp : public ArgOp > { + public: + explicit ArgMaxOp(OpKernelConstruction* context) + : ArgOp >(context) {} +}; + +template +class ArgMinOp : public ArgOp > { + public: + explicit ArgMinOp(OpKernelConstruction* context) + : ArgOp >(context) {} +}; + +#define REGISTER_ARGMAX(type) \ + REGISTER_KERNEL_BUILDER(Name("ArgMax") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("dimension"), \ + ArgMaxOp); \ + REGISTER_KERNEL_BUILDER(Name("ArgMin") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("dimension"), \ + ArgMinOp); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX); + +#if GOOGLE_CUDA + +// Forward declarations of the functor specializations for GPU. +namespace functor { + +#define DECLARE_GPU_SPEC(T, Dims) \ + template <> \ + void ArgMax::Reduce##Dims( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const int32 dimension, typename TTypes::Tensor output); \ + template <> \ + void ArgMin::Reduce##Dims( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const int32 dimension, typename TTypes::Tensor output); + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC(T, 1); \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); + +#define DECLARE_GPU_CLASS(T) \ + extern template struct ArgMax; \ + extern template struct ArgMin; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS); + +#undef DECLARE_GPU_SPECS +#undef DECLARE_GPU_CLASS + +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_ARGMAX_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("ArgMax") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("dimension"), \ + ArgMaxOp); \ + REGISTER_KERNEL_BUILDER(Name("ArgMin") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("dimension"), \ + ArgMinOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU); + +#undef REGISTER_ARGMAX_GPU + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/argmax_op.h b/tensorflow/core/kernels/argmax_op.h new file mode 100644 index 0000000000..41734f3254 --- /dev/null +++ b/tensorflow/core/kernels/argmax_op.h @@ -0,0 +1,55 @@ +#ifndef TENSORFLOW_KERNELS_ARGMAX_OP_H_ +#define TENSORFLOW_KERNELS_ARGMAX_OP_H_ +// Generator definition for ArgMaxOp, must be compilable by nvcc. 
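ArgOp above reduces one requested dimension to the index of its extremum, dispatching on rank through the HANDLE_DIM macro. A minimal standalone sketch for the common 2-D case (row-major array, reduction over the last dimension); this is plain C++, not the Eigen argmax the kernel uses:

// Argmax along the last dimension of a row-major rows x cols array,
// producing one index per row; ties resolve to the first maximum.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> ArgMaxLastDim(const std::vector<float>& x,
                                   int64_t rows, int64_t cols) {
  std::vector<int64_t> out(rows, 0);
  for (int64_t r = 0; r < rows; ++r) {
    float best = x[r * cols];
    for (int64_t c = 1; c < cols; ++c) {
      if (x[r * cols + c] > best) {
        best = x[r * cols + c];
        out[r] = c;
      }
    }
  }
  return out;
}

int main() {
  // 2 x 3 input: row 0 peaks at index 2, row 1 at index 0.
  std::vector<float> x = {1.f, 2.f, 5.f,
                          7.f, 0.f, 3.f};
  for (int64_t idx : ArgMaxLastDim(x, 2, 3)) std::cout << idx << " ";
  std::cout << "\n";  // prints: 2 0
}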
+ +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace functor { + +template +struct ArgMax { +#define DECLARE_COMPUTE_SPEC(Dims) \ + EIGEN_ALWAYS_INLINE static void Reduce##Dims( \ + const Device& d, typename TTypes::ConstTensor input, \ + const int32 dimension, \ + typename TTypes::Tensor output) { \ + output.device(d) = input.argmax(dimension).template cast(); \ + } + + DECLARE_COMPUTE_SPEC(1); + DECLARE_COMPUTE_SPEC(2); + DECLARE_COMPUTE_SPEC(3); + DECLARE_COMPUTE_SPEC(4); + DECLARE_COMPUTE_SPEC(5); + +#undef DECLARE_COMPUTE_SPEC +}; + +template +struct ArgMin { +#define DECLARE_COMPUTE_SPEC(Dims) \ + EIGEN_ALWAYS_INLINE static void Reduce##Dims( \ + const Device& d, typename TTypes::ConstTensor input, \ + const int32 dimension, \ + typename TTypes::Tensor output) { \ + output.device(d) = input.argmin(dimension).template cast(); \ + } + + DECLARE_COMPUTE_SPEC(1); + DECLARE_COMPUTE_SPEC(2); + DECLARE_COMPUTE_SPEC(3); + DECLARE_COMPUTE_SPEC(4); + DECLARE_COMPUTE_SPEC(5); + +#undef DECLARE_COMPUTE_SPEC +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_ARGMAX_OP_H_ diff --git a/tensorflow/core/kernels/argmax_op_gpu.cu.cc b/tensorflow/core/kernels/argmax_op_gpu.cu.cc new file mode 100644 index 0000000000..6c91fc2c86 --- /dev/null +++ b/tensorflow/core/kernels/argmax_op_gpu.cu.cc @@ -0,0 +1,20 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/argmax_op.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_GPU_SPEC(T) \ + template struct functor::ArgMax; \ + template struct functor::ArgMin; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC); + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/assign_op.h b/tensorflow/core/kernels/assign_op.h new file mode 100644 index 0000000000..3306f1eeaa --- /dev/null +++ b/tensorflow/core/kernels/assign_op.h @@ -0,0 +1,92 @@ +#ifndef TENSORFLOW_KERNELS_ASSIGN_OP_H_ +#define TENSORFLOW_KERNELS_ASSIGN_OP_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +// TODO(jeff): Get rid of use_exclusive_lock_ option + +// Computes *input[0] = input[1] +class AssignOp : public OpKernel { + public: + explicit AssignOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("use_locking", &use_exclusive_lock_)); + OP_REQUIRES_OK(context, + context->GetAttr("validate_shape", &validate_shape_)); + OP_REQUIRES(context, IsRefType(context->input_type(0)), + errors::InvalidArgument("lhs input needs to be a ref type")); + } + + void Compute(OpKernelContext* context) override { + Tensor rhs = context->input(1); + + // We always return the input ref. + context->forward_ref_input_to_ref_output(0, 0); + + // If the left hand side is not initialized, or the shape of the + // right-hand side is different than the left hand side, we need + // to allocate a new tensor. 
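The comment above captures the core decision in AssignOp::Compute: reallocate when the variable is uninitialized or its shape differs from the right-hand side, otherwise copy in place, optionally dropping the lock first. A standalone sketch of that control flow with hypothetical stand-in types (std::vector as the buffer, std::mutex as the ref mutex); it is not the TensorFlow implementation:

#include <algorithm>
#include <iostream>
#include <mutex>
#include <vector>

struct Var {
  std::vector<float> data;  // empty means "uninitialized"
  std::mutex mu;            // stands in for the input ref mutex
};

void Assign(Var& lhs, const std::vector<float>& rhs, bool use_exclusive_lock) {
  std::unique_lock<std::mutex> lock(lhs.mu);
  const bool same_shape = lhs.data.size() == rhs.size();
  if (lhs.data.empty() || !same_shape) {
    lhs.data = rhs;  // allocate a fresh buffer and hand it to the variable
    return;
  }
  if (!use_exclusive_lock) lock.unlock();  // copy outside the lock when allowed
  std::copy(rhs.begin(), rhs.end(), lhs.data.begin());
}

int main() {
  Var v;
  Assign(v, {1.f, 2.f, 3.f}, /*use_exclusive_lock=*/false);  // allocates
  Assign(v, {4.f, 5.f, 6.f}, /*use_exclusive_lock=*/true);   // copies in place
  std::cout << v.data[0] << " " << v.data[2] << "\n";        // prints: 4 6
}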
+ { + mutex_lock l(*context->input_ref_mutex(0)); + + Tensor old_lhs = context->mutable_input(0, true); + + if (validate_shape_) { + OP_REQUIRES( + context, old_lhs.shape().IsSameSize(rhs.shape()), + errors::InvalidArgument( + "Assign requires shapes of both tensors to match. lhs shape= ", + old_lhs.shape().ShortDebugString(), " rhs shape= ", + rhs.shape().ShortDebugString())); + } + + const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape()); + if (!old_lhs.IsInitialized() || !same_shape) { + // Create new tensor whose shape matches the right hand side + // and copy then hand off to lhs. + // We can't always know how this value will be used downstream, + // so make conservative assumptions in specifying the memory + // allocation attributes. + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + PersistentTensor copy; + Tensor* copyTensor = nullptr; + OP_REQUIRES_OK( + context, context->allocate_persistent(old_lhs.dtype(), rhs.shape(), + ©, ©Tensor, attr)); + Copy(context, copyTensor, rhs); + context->replace_ref_input(0, *copyTensor, true); + return; + } + + // The tensor has already been initialized and the right hand side + // matches the left hand side's shape. + if (use_exclusive_lock_) { + Copy(context, &old_lhs, rhs); + return; + } + } + + // The tensor has already been initialized and the right hand side + // matches the left hand side's shape. We have been told to do the + // copy outside the lock. + Tensor old_unlocked_lhs = context->mutable_input(0, false); + Copy(context, &old_unlocked_lhs, rhs); + } + + virtual void Copy(OpKernelContext* context, Tensor* lhs, + const Tensor& rhs) = 0; + + bool use_exclusive_lock_; + bool validate_shape_; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_ASSIGN_OP_H_ diff --git a/tensorflow/core/kernels/attention_ops.cc b/tensorflow/core/kernels/attention_ops.cc new file mode 100644 index 0000000000..28763f65a4 --- /dev/null +++ b/tensorflow/core/kernels/attention_ops.cc @@ -0,0 +1,92 @@ +// See docs in ../ops/attention_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" + +namespace tensorflow { + +class ExtractGlimpseOp : public OpKernel { + public: + explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_)); + OP_REQUIRES_OK(context, context->GetAttr("centered", ¢ered_)); + OP_REQUIRES_OK(context, context->GetAttr("uniform_noise", &uniform_noise_)); + } + + // Expect input tensor of rank 4 with dimensions (batch_size, height, width, + // depth). 
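For intuition, a heavily simplified standalone sketch of a glimpse extraction on a single-channel image: copy a window centered at a pixel offset and zero-fill anything that falls outside the input. The kernel below delegates to Eigen::ExtractGlimpses, which additionally handles normalized and centered offsets, the (x, y) vs. (y, x) convention, and uniform-noise fill; none of that is modeled here:

#include <iostream>
#include <vector>

// Copy a gh x gw window centered at (cy, cx) out of an h x w image,
// zero-padding out-of-range positions.
std::vector<float> Glimpse(const std::vector<float>& img, int h, int w,
                           int gh, int gw, int cy, int cx) {
  std::vector<float> out(gh * gw, 0.f);
  const int top = cy - gh / 2;
  const int left = cx - gw / 2;
  for (int y = 0; y < gh; ++y) {
    for (int x = 0; x < gw; ++x) {
      const int sy = top + y, sx = left + x;
      if (sy >= 0 && sy < h && sx >= 0 && sx < w) {
        out[y * gw + x] = img[sy * w + sx];
      }
    }
  }
  return out;
}

int main() {
  std::vector<float> img(16);
  for (int i = 0; i < 16; ++i) img[i] = static_cast<float>(i);  // 4 x 4 ramp
  auto g = Glimpse(img, 4, 4, /*gh=*/2, /*gw=*/2, /*cy=*/1, /*cx=*/2);
  std::cout << g[0] << " " << g[1] << " " << g[2] << " " << g[3] << "\n";  // prints: 1 2 5 6
}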
+ void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const TensorShape input_shape = input.shape(); + const int32 num_dims = input_shape.dims(); + OP_REQUIRES( + context, num_dims == 4, + errors::InvalidArgument( + "input must be 4-dimensional (batch_size, height, width, depth)", + input_shape.ShortDebugString())); + + const int64 batch_size = input_shape.dim_size(0); + + const Tensor& window_size = context->input(1); + OP_REQUIRES(context, (window_size.shape().dims() == 1) && + window_size.shape().dim_size(0) == 2, + errors::InvalidArgument( + "input must be a vector of size 2 (height, width)", + window_size.shape().ShortDebugString())); + + const int64 output_height = window_size.tensor()(0); + const int64 output_width = window_size.tensor()(1); + TensorShape output_shape = input_shape; + output_shape.set_dim(1, output_height); + output_shape.set_dim(2, output_width); + + const Tensor& offsets = context->input(2); + OP_REQUIRES(context, offsets.shape().dims() == 2, + errors::InvalidArgument("input must be a matrix", + offsets.shape().ShortDebugString())); + OP_REQUIRES(context, offsets.shape().dim_size(0) == batch_size, + errors::InvalidArgument("first dimension should be batch", + offsets.shape().ShortDebugString())); + OP_REQUIRES( + context, offsets.shape().dim_size(1) == 2, + errors::InvalidArgument("second dimension should be of size 2 (y,x)", + offsets.shape().ShortDebugString())); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + std::vector > offset_vec; + offset_vec.reserve(batch_size); + for (int i = 0; i < batch_size; ++i) { + float offset_y = offsets.tensor()(i, 0); + float offset_x = offsets.tensor()(i, 1); + // Eigen::ExtractGlimpses expects offsets as (x,y), whereas the + // calling TensorFlow operates with (y,x) as indices. + offset_vec.push_back(Eigen::IndexPair(offset_x, offset_y)); + } + + output->tensor().swap_layout().device( + context->eigen_cpu_device()) = + Eigen::ExtractGlimpses(input.tensor().swap_layout(), + output_width, output_height, offset_vec, + normalized_, centered_, uniform_noise_); + } + + private: + bool normalized_; + bool centered_; + bool uniform_noise_; +}; + +REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU), + ExtractGlimpseOp); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc new file mode 100644 index 0000000000..26f98ffbcd --- /dev/null +++ b/tensorflow/core/kernels/avgpooling_op.cc @@ -0,0 +1,418 @@ +// See docs in ../ops/nn_ops.cc. 
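Before the kernel code, a standalone sketch of the computation it implements for a single 2-D plane with VALID padding; the kernel itself operates on 4-D NHWC batches and also supports SAME padding and a GPU path:

#include <iostream>
#include <vector>

// VALID average pooling over one h x w plane with a kh x kw window and
// (sh, sw) strides.
std::vector<float> AvgPool2D(const std::vector<float>& in, int h, int w,
                             int kh, int kw, int sh, int sw) {
  const int oh = (h - kh) / sh + 1;
  const int ow = (w - kw) / sw + 1;
  std::vector<float> out(oh * ow, 0.f);
  for (int r = 0; r < oh; ++r) {
    for (int c = 0; c < ow; ++c) {
      float sum = 0.f;
      for (int y = 0; y < kh; ++y)
        for (int x = 0; x < kw; ++x)
          sum += in[(r * sh + y) * w + (c * sw + x)];
      out[r * ow + c] = sum / static_cast<float>(kh * kw);
    }
  }
  return out;
}

int main() {
  std::vector<float> in = { 1,  2,  3,  4,
                            5,  6,  7,  8,
                            9, 10, 11, 12,
                           13, 14, 15, 16};  // 4 x 4
  for (float v : AvgPool2D(in, 4, 4, /*kh=*/2, /*kw=*/2, /*sh=*/2, /*sw=*/2))
    std::cout << v << " ";  // prints: 3.5 5.5 11.5 13.5
  std::cout << "\n";
}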
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/avgpooling_op.h" + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/pooling_ops_common.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" + +#if GOOGLE_CUDA +#include "tensorflow/core/kernels/maxpooling_op_gpu.h" +#include "tensorflow/core/kernels/pooling_ops_common_gpu.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class AvgPoolingOp : public UnaryOp { + public: + explicit AvgPoolingOp(OpKernelConstruction* context) : UnaryOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument( + "Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument( + "Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + PoolParameters params{context, ksize_, stride_, padding_, + tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + OP_REQUIRES(context, params.depth_window == 1, + errors::Unimplemented( + "Non-spatial pooling is not " + "yet supported. Volunteers? :)")); + + // For avgpooling, tensor_in should have 4 dimensions. + OP_REQUIRES(context, tensor_in.dims() == 4, + errors::InvalidArgument("tensor_in must be 4-dimensional")); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, params.forward_output_shape(), &output)); + + if (std::is_same::value) { + Eigen::PaddingType pt = BrainPadding2EigenPadding(padding_); + functor::SpatialAvgPooling()( + context->eigen_device(), output->tensor(), + tensor_in.tensor(), params.window_rows, params.window_cols, + params.row_stride, params.col_stride, pt); + } else { + SpatialAvgPool(context, output, tensor_in, params, padding_); + } + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER(Name("AvgPool") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + AvgPoolingOp); + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. 
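PoolParameters and Get2dOutputSize derive the output extents and padding from ksize, strides, and the padding attribute. A standalone sketch of the standard SAME/VALID arithmetic, assumed here to match those helpers; in particular, how the needed padding is split between the two sides is an assumption:

#include <algorithm>
#include <iostream>

struct Dim { int out_size; int pad_before; };

Dim OutputDim(int in_size, int window, int stride, bool same_padding) {
  Dim d;
  if (same_padding) {
    d.out_size = (in_size + stride - 1) / stride;  // ceil(in_size / stride)
    const int needed =
        std::max(0, (d.out_size - 1) * stride + window - in_size);
    d.pad_before = needed / 2;  // the rest of the padding goes after
  } else {                      // VALID: no padding
    d.out_size = (in_size - window) / stride + 1;
    d.pad_before = 0;
  }
  return d;
}

int main() {
  Dim same = OutputDim(/*in_size=*/10, /*window=*/3, /*stride=*/2, true);
  Dim valid = OutputDim(10, 3, 2, false);
  std::cout << same.out_size << " " << same.pad_before << "\n";    // prints: 5 0
  std::cout << valid.out_size << " " << valid.pad_before << "\n";  // prints: 4 0
}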
+namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void SpatialAvgPooling::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, int window_rows, \ + int window_cols, int row_stride, int col_stride, \ + const Eigen::PaddingType& padding); \ + extern template struct SpatialAvgPooling; + +DECLARE_GPU_SPEC(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNEL_BUILDER(Name("AvgPool") + .Device(DEVICE_GPU) + .TypeConstraint("T"), + AvgPoolingOp); +#endif // GOOGLE_CUDA + +// The operation to compute AvgPool gradients. +// It takes two inputs: +// - The original input tensor shape +// - Backprop tensor for output +// It produces one output: backprop tensor for input. +template +class AvgPoolingGradOp : public OpKernel { + public: + explicit AvgPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument( + "Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument( + "Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in_shape = context->input(0); + const Tensor& out_backprop = context->input(1); + // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements. + OP_REQUIRES(context, tensor_in_shape.dims() == 1 && + tensor_in_shape.NumElements() == 4, + errors::InvalidArgument( + "out_backprop must be 1-dimensional and 4 " + "elements")); + // For avgpooling, out_backprop should have 4 dimensions. + OP_REQUIRES(context, out_backprop.dims() == 4, + errors::InvalidArgument("out_backprop must be 4-dimensional")); + const int64 out_backprop_batch = out_backprop.dim_size(0); + const int64 out_backprop_rows = out_backprop.dim_size(1); + const int64 out_backprop_cols = out_backprop.dim_size(2); + const int64 out_backprop_depth = out_backprop.dim_size(3); + + TensorShape output_shape; + auto shape_vec = tensor_in_shape.vec(); + for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) { + output_shape.AddDim(shape_vec(i)); + } + const int64 in_rows = output_shape.dim_size(1); + const int64 in_cols = output_shape.dim_size(2); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + output->flat().setZero(); + + const int window_rows = ksize_[1]; + const int window_cols = ksize_[2]; + const int depth_window = ksize_[3]; + + const int row_stride = stride_[1]; + const int col_stride = stride_[2]; + + // We (will) use different code for spatial pooling and + // non-spatial pooling. + // + // Spatial pooling is when depth_window = 1 + OP_REQUIRES(context, depth_window == 1, + errors::Unimplemented( + "Non-spatial pooling is not " + "yet supported. Volunteers? 
:)")); + + int out_height, out_width, pad_rows, pad_cols; + OP_REQUIRES_OK( + context, Get2dOutputSize(in_rows, in_cols, window_rows, window_cols, + row_stride, col_stride, padding_, &out_height, + &out_width, &pad_rows, &pad_cols)); + + const T* out_backprop_ptr = out_backprop.flat().data(); + T* input_backprop_ptr = output->flat().data(); + + for (int64 b = 0; b < out_backprop_batch; ++b) { + for (int64 r = 0; r < out_backprop_rows; ++r) { + // Calculates row broadcast size. For SAME padding, current + // index could be in the padding area, and r*row_stride + + // window_rows could be beyond the input tensor's boundary. In + // such cases, change the starting index and reduce the + // broadcast size. + int rindex, rsize; + OP_REQUIRES_OK(context, + GetBroadcastSize(r, in_rows, window_rows, row_stride, + pad_rows, &rindex, &rsize)); + for (int64 c = 0; c < out_backprop_cols; ++c) { + // Calculates col broadcast size. For SAME padding, current + // index could be in the padding area, and c*col_stride + + // window_cols could be beyond the input tensor's boundary. In + // such cases, change the starting index and reduce the + // broadcast size. + int cindex, csize; + OP_REQUIRES_OK(context, + GetBroadcastSize(c, in_cols, window_cols, col_stride, + pad_cols, &cindex, &csize)); + + T divide_coeff = 1.0 / (rsize * csize); + int64 output_index = + (b * out_backprop_rows + r) * out_backprop_cols + c; + for (int64 r_dst = rindex; r_dst < rindex + rsize; ++r_dst) { + for (int64 c_dst = cindex; c_dst < cindex + csize; ++c_dst) { + int64 input_index = (b * in_rows + r_dst) * in_cols + c_dst; + const T* output_offset = + out_backprop_ptr + output_index * out_backprop_depth; + T* input_offset = + input_backprop_ptr + input_index * out_backprop_depth; + for (int64 d = 0; d < out_backprop_depth; ++d) { + *input_offset += *output_offset * divide_coeff; + ++output_offset; + ++input_offset; + } + } + } + } + } + } + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .HostMemory("orig_input_shape"), + AvgPoolingGradOp); +REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .HostMemory("orig_input_shape"), + AvgPoolingGradOp); + +#if GOOGLE_CUDA + +// A CUDNN based AvgPoolingGrad implementation. It includes the padding as the +// candidates for the pooling operation. +template +class AvgPoolingGradOp : public OpKernel { + public: + typedef GPUDevice Device; + + explicit AvgPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in_shape = context->input(0); + const Tensor& out_backprop = context->input(1); + // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements. 
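The CPU gradient loops above spread each output-gradient element uniformly over the input cells of its border-clipped pooling window, scaled by one over the clipped window size (divide_coeff). A standalone 1-D sketch of that distribution with VALID padding:

#include <algorithm>
#include <iostream>
#include <vector>

std::vector<float> AvgPoolGrad1D(const std::vector<float>& out_grad,
                                 int in_size, int window, int stride) {
  std::vector<float> in_grad(in_size, 0.f);
  for (int o = 0; o < static_cast<int>(out_grad.size()); ++o) {
    const int start = o * stride;
    const int end = std::min(start + window, in_size);  // clip at the border
    const float coeff = 1.f / static_cast<float>(end - start);
    for (int i = start; i < end; ++i) in_grad[i] += out_grad[o] * coeff;
  }
  return in_grad;
}

int main() {
  // 4 inputs, window 2, stride 2 -> 2 outputs; every input receives half of
  // the gradient of the output it contributed to.
  for (float g : AvgPoolGrad1D({1.f, 2.f}, 4, 2, 2)) std::cout << g << " ";
  std::cout << "\n";  // prints: 0.5 0.5 1 1
}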
+ OP_REQUIRES( + context, + tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4, + errors::InvalidArgument("out_backprop must be 1-dimensional and 4 " + "elements")); + // For avgpooling, out_backprop should have 4 dimensions. + OP_REQUIRES(context, out_backprop.dims() == 4, + errors::InvalidArgument("out_backprop must be 4-dimensional")); + + TensorShape output_shape; + auto shape_vec = tensor_in_shape.vec(); + for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) { + output_shape.AddDim(shape_vec(i)); + } + + DnnPoolingGradOp::Compute( + context, perftools::gputools::dnn::PoolingMode::kAverage, ksize_, + stride_, padding_, nullptr, nullptr, out_backprop, output_shape); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("orig_input_shape") + .Label("cudnn"), + AvgPoolingGradOp); + +// A custom GPU kernel based AvgPoolingGrad implementation. It includes the +// padding as the candidates for the pooling operation. +template +class AvgPoolingGradOpCustomGPUKernel : public OpKernel { + public: + typedef GPUDevice Device; + + explicit AvgPoolingGradOpCustomGPUKernel(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in_shape = context->input(0); + const Tensor& out_backprop = context->input(1); + // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements. + OP_REQUIRES( + context, + tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4, + errors::InvalidArgument("out_backprop must be 1-dimensional and 4 " + "elements")); + // For avgpooling, out_backprop should have 4 dimensions. + OP_REQUIRES(context, out_backprop.dims() == 4, + errors::InvalidArgument("out_backprop must be 4-dimensional")); + const int64 out_backprop_batch = out_backprop.dim_size(0); + const int64 out_backprop_rows = out_backprop.dim_size(1); + const int64 out_backprop_cols = out_backprop.dim_size(2); + const int64 out_backprop_depth = out_backprop.dim_size(3); + + TensorShape output_shape; + auto shape_vec = tensor_in_shape.vec(); + for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) { + output_shape.AddDim(shape_vec(i)); + } + const int64 in_rows = output_shape.dim_size(1); + const int64 in_cols = output_shape.dim_size(2); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + const int window_rows = ksize_[1]; + const int window_cols = ksize_[2]; + const int depth_window = ksize_[3]; + + const int row_stride = stride_[1]; + const int col_stride = stride_[2]; + + // We (will) use different code for spatial pooling and + // non-spatial pooling. 
+ // + // Spatial pooling is when depth_window = 1 + OP_REQUIRES(context, depth_window == 1, + errors::Unimplemented("Non-spatial pooling is not " + "yet supported. Volunteers? :)")); + + int out_height, out_width, pad_rows, pad_cols; + OP_REQUIRES_OK( + context, Get2dOutputSize(in_rows, in_cols, window_rows, window_cols, + row_stride, col_stride, padding_, &out_height, + &out_width, &pad_rows, &pad_cols)); + + RunAvePoolBackwardNHWC(out_backprop.flat().data(), // top_diff + out_backprop_batch, // num + in_rows, // height + in_cols, // width + out_backprop_depth, // channels + out_backprop_rows, // pooled_height + out_backprop_cols, // pooled_width + window_rows, // kernel_h + window_cols, // kernel_w + row_stride, // stride_h + col_stride, // stride_w + pad_rows, // pad_t + pad_cols, // pad_l + output->flat().data(), // bottom_diff + context->eigen_gpu_device()); // d + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("orig_input_shape"), + AvgPoolingGradOpCustomGPUKernel); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h new file mode 100644 index 0000000000..38f0eb97e5 --- /dev/null +++ b/tensorflow/core/kernels/avgpooling_op.h @@ -0,0 +1,58 @@ +#ifndef TENSORFLOW_KERNELS_AVGPOOLING_OP_H_ +#define TENSORFLOW_KERNELS_AVGPOOLING_OP_H_ +// Functor definition for AvgPoolingOp, must be compilable by nvcc. + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" + +namespace tensorflow { +namespace functor { + +template +struct SpatialAvgPooling { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, int window_rows, + int window_cols, int row_stride, int col_stride, + const Eigen::PaddingType& padding) { + // Because we swap the layout, we swap the row/cols as well + output.swap_layout().device(d) = + Eigen::SpatialAvgPooling(input.swap_layout(), window_cols, window_rows, + col_stride, row_stride, padding); + } +}; + +} // namespace functor + +typedef Eigen::GpuDevice GPUDevice; + +// Lauch a custom GPU kernels from Yanqing for the avgpooling backward operation +// that works NHWC data formats. +// Arguments: +// top_diff: backprop to the output of the pooling layer +// num: number of input batches +// height: input height +// width: input width +// channels: number of input channels +// pooled_height: the height of the output to the pooling layer +// pooled_width: the width of the output to the pooling layer +// kernel_h: the height of the pooling kernel +// kernel_w: the width of the pooling kernel +// stride_h: the height of the vertical stride +// stride_w: the width of the horizontal stride +// pad_t: padding size to the top side +// pad_l: padding size to the left side +// bottom_diff: backprop to the input of the pooling layer. 
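The CUDA kernel declared below walks a flat NHWC index and decomposes it into (n, h, w, c) with division and modulo (the real kernel also offsets h and w by the padding). A standalone sketch of just that index arithmetic:

#include <iostream>

int main() {
  const int height = 3, width = 4, channels = 2;
  // Build the flat NHWC index for (n=1, h=2, w=3, c=1) ...
  const int index = 1 * height * width * channels + 2 * width * channels +
                    3 * channels + 1;
  // ... and recover the coordinates the same way the kernel does.
  const int c = index % channels;
  const int w = index / channels % width;
  const int h = index / channels / width % height;
  const int n = index / channels / width / height;
  std::cout << n << " " << h << " " << w << " " << c << "\n";  // prints: 1 2 3 1
}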
+template +bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, + const int pad_l, T* const bottom_diff, + const GPUDevice& d); + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_AVGPOOLING_OP_H_ diff --git a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc new file mode 100644 index 0000000000..ec84ee6862 --- /dev/null +++ b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc @@ -0,0 +1,101 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include +#include + +#include "tensorflow/core/kernels/avgpooling_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::SpatialAvgPooling; + +DEFINE_GPU_KERNELS(float) + +#undef DEFINE_GPU_KERNELS + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +static const int CAFFE_CUDA_NUM_THREADS = 1024; + +template +__global__ void AvePoolBackwardNHWC(const int nthreads, + const dtype* const top_diff, const int num, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, + const int pad_l, dtype* const bottom_diff) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + const int c = index % channels; + const int w = index / channels % width + pad_l; + const int h = (index / channels / width) % height + pad_t; + const int n = index / channels / width / height; + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 
0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + dtype gradient = 0; + const dtype* const top_diff_slice = + top_diff + n * pooled_height * pooled_width * channels + c; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + int pool_size = (hend - hstart) * (wend - wstart); + gradient += + top_diff_slice[(ph * pooled_width + pw) * channels] / pool_size; + } + } + bottom_diff[index] = gradient; + } +} + +template +bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, + const int pad_l, T* const bottom_diff, + const GPUDevice& d) { + int x_size = num * height * width * channels; + int thread_per_block = + std::min(CAFFE_CUDA_NUM_THREADS, d.maxCudaThreadsPerMultiProcessor()); + int block_count = (x_size + thread_per_block - 1) / thread_per_block; + AvePoolBackwardNHWC<<>>( + x_size, top_diff, num, height, width, channels, pooled_height, + pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_t, + bottom_diff); + + return d.ok(); +} + +template bool RunAvePoolBackwardNHWC( + const float* const top_diff, const int num, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + float* const bottom_diff, const GPUDevice& d); + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op.cc new file mode 100644 index 0000000000..349aac0158 --- /dev/null +++ b/tensorflow/core/kernels/batch_matmul_op.cc @@ -0,0 +1,260 @@ +// See docs in ../ops/math_ops.cc. + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/util/work_sharder.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/stream_executor/stream.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct LaunchBatchMatMul; + +template +struct LaunchBatchMatMul { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) { + auto Tx = in_x.tensor(); + auto Ty = in_y.tensor(); + auto Tz = out->tensor(); + + // Shards "n"-matmuls into "num" shards. Each shard is + // dispatched to a thread. 
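Before the Shard call below, a standalone sketch of the sharding pattern it relies on: split the batch range into contiguous pieces and hand each piece to its own worker, which then runs its slice of the per-batch matmuls. The real tensorflow::Shard also weighs cost_per_unit when sizing shards and reuses a thread pool; this sketch just spawns plain std::threads:

#include <algorithm>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

void ShardRange(int num_workers, int total,
                const std::function<void(int, int)>& work) {
  std::vector<std::thread> threads;
  const int per_worker = (total + num_workers - 1) / num_workers;
  for (int start = 0; start < total; start += per_worker) {
    const int limit = std::min(start + per_worker, total);
    threads.emplace_back(work, start, limit);  // process [start, limit)
  }
  for (auto& t : threads) t.join();
}

int main() {
  std::vector<int> results(8, 0);
  ShardRange(/*num_workers=*/3, /*total=*/8, [&](int start, int limit) {
    for (int i = start; i < limit; ++i) results[i] = i * i;  // stand-in for one matmul
  });
  for (int r : results) std::cout << r << " ";  // prints: 0 1 4 9 16 25 36 49
  std::cout << "\n";
}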
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + const int64 num_units = in_x.dim_size(0); + const int64 cost_per_unit = + in_x.dim_size(0) * in_x.dim_size(1) * out->dim_size(2); + Shard(worker_threads.num_threads, worker_threads.workers, num_units, + cost_per_unit, [&Tx, &Ty, adj_x, adj_y, &Tz](int start, int limit) { + LaunchBatchMatMul::Run(Tx, Ty, adj_x, adj_y, Tz, + start, limit); + }); + } + + template + static void Run(In Tx, In Ty, bool adj_x, bool adj_y, Out Tz, int start, + int limit) { + Eigen::array, 1> contract_pairs; + + Eigen::internal::scalar_conjugate_op conj; + if (!adj_x && !adj_y) { + for (int i = start; i < limit; ++i) { + auto x = Tx.template chip<0>(i); + auto y = Ty.template chip<0>(i); + auto z = Tz.template chip<0>(i); + contract_pairs[0] = Eigen::IndexPair(1, 0); + z = x.contract(y, contract_pairs); // matmul + } + } else if (!adj_x && adj_y) { + for (int i = start; i < limit; ++i) { + auto x = Tx.template chip<0>(i); + auto y = Ty.template chip<0>(i).unaryExpr(conj); + auto z = Tz.template chip<0>(i); + contract_pairs[0] = Eigen::IndexPair(1, 1); + z = x.contract(y, contract_pairs); // matmul + } + } else if (adj_x && !adj_y) { + for (int i = start; i < limit; ++i) { + auto x = Tx.template chip<0>(i).unaryExpr(conj); + auto y = Ty.template chip<0>(i); + auto z = Tz.template chip<0>(i); + contract_pairs[0] = Eigen::IndexPair(0, 0); + z = x.contract(y, contract_pairs); // matmul + } + } else { + for (int i = start; i < limit; ++i) { + auto x = Tx.template chip<0>(i).unaryExpr(conj); + auto y = Ty.template chip<0>(i).unaryExpr(conj); + auto z = Tz.template chip<0>(i); + contract_pairs[0] = Eigen::IndexPair(0, 1); + z = x.contract(y, contract_pairs); // matmul + } + } + } +}; + +#if GOOGLE_CUDA + +namespace { +template +perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory) { + perftools::gputools::DeviceMemoryBase wrapped(const_cast(cuda_memory)); + perftools::gputools::DeviceMemory typed(wrapped); + return typed; +} +} // namespace + +template +struct LaunchBatchMatMul { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) { + perftools::gputools::blas::Transpose trans[] = { + perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose}; + const uint64 m = in_x.dim_size(adj_x ? 2 : 1); + const uint64 k = in_x.dim_size(adj_x ? 1 : 2); + const uint64 n = in_y.dim_size(adj_y ? 
1 : 2); + const uint64 batch_size = in_x.dim_size(0); + auto blas_transpose_a = trans[adj_x]; + auto blas_transpose_b = trans[adj_y]; + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + typedef perftools::gputools::DeviceMemory DeviceMemoryType; + std::vector a_device_memory; + std::vector b_device_memory; + std::vector c_device_memory; + std::vector a_ptrs; + std::vector b_ptrs; + std::vector c_ptrs; + a_device_memory.reserve(batch_size); + b_device_memory.reserve(batch_size); + c_device_memory.reserve(batch_size); + a_ptrs.reserve(batch_size); + b_ptrs.reserve(batch_size); + c_ptrs.reserve(batch_size); + auto* a_base_ptr = in_x.template flat().data(); + auto* b_base_ptr = in_y.template flat().data(); + auto* c_base_ptr = out->template flat().data(); + for (int64 i = 0; i < batch_size; ++i) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); + b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); + c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); + a_ptrs.push_back(&a_device_memory.back()); + b_ptrs.push_back(&b_device_memory.back()); + c_ptrs.push_back(&c_device_memory.back()); + } + + // Cublas does + // C = A x B + // where A, B and C are assumed to be in column major. + // We want the output to be in row-major, so we can compute + // C' = B' x A' (' stands for transpose) + bool blas_launch_status = + stream->ThenBlasGemmBatched(blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), b_ptrs, + adj_y ? k : n, a_ptrs, adj_x ? m : k, + static_cast(0.0), c_ptrs, n, + batch_size) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas SGEMMBatched launch failed : a.shape=", + in_x.shape().DebugString(), ", b.shape=", in_y.shape().DebugString(), + ", m=", m, ", n=", n, ", k=", k, ", batch_size=", batch_size)); + } + } +}; + +#endif // GOOGLE_CUDA + +template +class BatchMatMul : public OpKernel { + public: + explicit BatchMatMul(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); + OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); + } + + virtual ~BatchMatMul() {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& in0 = ctx->input(0); + const Tensor& in1 = ctx->input(1); + OP_REQUIRES(ctx, in0.dims() == in1.dims(), + errors::InvalidArgument("In[0] and In[1] has different ndims: ", + in0.shape().ShortDebugString(), " vs. ", + in1.shape().ShortDebugString())); + const int ndims = in0.dims(); + OP_REQUIRES( + ctx, ndims >= 3, + errors::InvalidArgument("In[0] and In[1] ndims must be >= 3: ", ndims)); + TensorShape out_shape; + for (int i = 0; i < ndims - 2; ++i) { + OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i), + errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", + i, ") must be the same: ", + in0.shape().DebugString(), " vs ", + in1.shape().DebugString())); + out_shape.AddDim(in0.dim_size(i)); + } + auto n = out_shape.num_elements(); + auto d0 = in0.dim_size(ndims - 2); + auto d1 = in0.dim_size(ndims - 1); + Tensor in0_reshaped; + CHECK(in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1}))); + auto d2 = in1.dim_size(ndims - 2); + auto d3 = in1.dim_size(ndims - 1); + Tensor in1_reshaped; + CHECK(in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3}))); + if (adj_x_) std::swap(d0, d1); + if (adj_y_) std::swap(d2, d3); + OP_REQUIRES(ctx, d1 == d2, + errors::InvalidArgument( + "In[0] mismatch In[1] shape: ", d1, " vs. 
", d2, ": ", + in0.shape().ShortDebugString(), " ", + in1.shape().ShortDebugString(), " ", adj_x_, " ", adj_y_)); + out_shape.AddDim(d0); + out_shape.AddDim(d3); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); + if (out->NumElements() == 0) { + return; + } + if (in0.NumElements() == 0 || in1.NumElements() == 0) { + functor::SetZeroFunctor f; + f(ctx->eigen_device(), out->flat()); + return; + } + Tensor out_reshaped; + CHECK(out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3}))); + LaunchBatchMatMul::Launch(ctx, in0_reshaped, in1_reshaped, + adj_x_, adj_y_, &out_reshaped); + } + + private: + bool adj_x_; + bool adj_y_; +}; + +#define REGISTER_CPU(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ + BatchMatMul) + +#define REGISTER_GPU(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint("T"), \ + BatchMatMul) + +REGISTER_CPU(float); +REGISTER_CPU(double); +REGISTER_CPU(int32); +REGISTER_CPU(complex64); + +#ifdef GOOGLE_CUDA +// TODO(kalakris): The GPU implementation is currently disabled due to issues +// encountered in practice. See b/24534272. +// REGISTER_GPU(float); +#endif // GOOGLE_CUDA + +#undef REGISTER_CPU +#undef REGISTER_GPU +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc new file mode 100644 index 0000000000..c67c921631 --- /dev/null +++ b/tensorflow/core/kernels/batch_norm_op.cc @@ -0,0 +1,223 @@ +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/batch_norm_op.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class BatchNormOp : public OpKernel { + public: + explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("variance_epsilon", &variance_epsilon_)); + OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization", + &scale_after_normalization_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& mean = context->input(1); + const Tensor& var = context->input(2); + const Tensor& beta = context->input(3); + const Tensor& gamma = context->input(4); + + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().ShortDebugString())); + OP_REQUIRES(context, mean.dims() == 1, + errors::InvalidArgument("mean must be 1-dimensional", + mean.shape().ShortDebugString())); + OP_REQUIRES(context, var.dims() == 1, + errors::InvalidArgument("var must be 1-dimensional", + var.shape().ShortDebugString())); + OP_REQUIRES(context, beta.dims() == 1, + errors::InvalidArgument("beta must be 1-dimensional", + beta.shape().ShortDebugString())); + OP_REQUIRES(context, gamma.dims() == 1, + errors::InvalidArgument("gamma must be 1-dimensional", + gamma.shape().ShortDebugString())); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + + functor::BatchNorm()( + context->eigen_device(), input.tensor(), mean.vec(), + var.vec(), beta.vec(), gamma.vec(), variance_epsilon_, + 
scale_after_normalization_, output->tensor()); + } + + private: + float variance_epsilon_; + bool scale_after_normalization_; +}; + +template +class BatchNormGradOp : public OpKernel { + public: + explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("variance_epsilon", &variance_epsilon_)); + OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization", + &scale_after_normalization_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& mean = context->input(1); + const Tensor& var = context->input(2); + const Tensor& gamma = context->input(3); + const Tensor& out_backprop = context->input(4); + + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().ShortDebugString())); + OP_REQUIRES(context, mean.dims() == 1, + errors::InvalidArgument("mean must be 1-dimensional", + mean.shape().ShortDebugString())); + OP_REQUIRES(context, var.dims() == 1, + errors::InvalidArgument("var must be 1-dimensional", + var.shape().ShortDebugString())); + OP_REQUIRES(context, gamma.dims() == 1, + errors::InvalidArgument("gamma must be 1-dimensional", + gamma.shape().ShortDebugString())); + OP_REQUIRES( + context, out_backprop.dims() == 4, + errors::InvalidArgument("out_backprop must be 4-dimensional", + out_backprop.shape().ShortDebugString())); + + Tensor* dx = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &dx)); + Tensor* dm = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, mean.shape(), &dm)); + Tensor* dv = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(2, var.shape(), &dv)); + Tensor* db = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db)); + Tensor* dg = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg)); + + // Scratch buffer of [depth] dimension, aka the 4th dimension of input, + // which is dim_size(3), for calculating various combinations of + // (var + epsilon). + Tensor scratch1; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({input.dim_size(3)}), &scratch1)); + + // Scratch buffer of [depth] dimension for saving intermediate calculation + // values. + Tensor scratch2; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({input.dim_size(3)}), &scratch2)); + + functor::BatchNormGrad()( + context->eigen_device(), input.tensor(), mean.vec(), + var.vec(), gamma.vec(), out_backprop.tensor(), + variance_epsilon_, scale_after_normalization_, dx->tensor(), + dm->vec(), dv->vec(), db->vec(), dg->vec(), + scratch1.vec(), scratch2.vec()); + } + + private: + float variance_epsilon_; + bool scale_after_normalization_; +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + BatchNormOp); + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. 
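For reference, a standalone per-element sketch of the normalization BatchNormOp applies after reshaping the input to [rest, depth]: y = (x - mean) * rsqrt(var + eps) * gamma + beta, with the gamma factor skipped when scale_after_normalization is false. Plain loops stand in for the Eigen broadcasting used by the functor:

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

void BatchNormForward(const std::vector<float>& x, int depth,
                      const std::vector<float>& mean,
                      const std::vector<float>& var,
                      const std::vector<float>& beta,
                      const std::vector<float>& gamma, float eps,
                      bool scale_after_normalization, std::vector<float>& y) {
  for (std::size_t i = 0; i < x.size(); ++i) {
    const int d = static_cast<int>(i) % depth;
    const float rstd = 1.f / std::sqrt(var[d] + eps);
    const float scaled = (x[i] - mean[d]) * rstd;
    y[i] = (scale_after_normalization ? scaled * gamma[d] : scaled) + beta[d];
  }
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};  // two rows, depth 2
  std::vector<float> y(x.size());
  BatchNormForward(x, /*depth=*/2, /*mean=*/{2.f, 3.f}, /*var=*/{1.f, 1.f},
                   /*beta=*/{0.f, 0.f}, /*gamma=*/{1.f, 1.f}, /*eps=*/0.f,
                   /*scale_after_normalization=*/true, y);
  for (float v : y) std::cout << v << " ";  // prints: -1 -1 1 1
  std::cout << "\n";
}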
+namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void BatchNorm::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + typename TTypes::ConstVec mean, typename TTypes::ConstVec var, \ + typename TTypes::ConstVec beta, typename TTypes::ConstVec gamma, \ + float variance_epsilon, bool scale_after_normalization, \ + typename TTypes::Tensor output); \ + extern template struct BatchNorm; + +#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); + +DECLARE_GPU_SPECS(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + BatchNormOp); + +REGISTER_GPU_KERNEL(float); +#undef REGISTER_GPU_KERNEL + +#endif // GOOGLE_CUDA + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + BatchNormGradOp); + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void BatchNormGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + typename TTypes::ConstVec mean, typename TTypes::ConstVec var, \ + typename TTypes::ConstVec gamma, \ + typename TTypes::ConstTensor out_backprop, float variance_epsilon, \ + bool scale_after_normalization, typename TTypes::Tensor dx, \ + typename TTypes::Vec dm, typename TTypes::Vec dv, \ + typename TTypes::Vec db, typename TTypes::Vec dg, \ + typename TTypes::Vec scratch1, typename TTypes::Vec scratch2); \ + extern template struct BatchNormGrad; + +#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); + +DECLARE_GPU_SPECS(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + BatchNormGradOp); + +REGISTER_GPU_KERNEL(float); +#undef REGISTER_GPU_KERNEL + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_norm_op.h b/tensorflow/core/kernels/batch_norm_op.h new file mode 100644 index 0000000000..5981e58460 --- /dev/null +++ b/tensorflow/core/kernels/batch_norm_op.h @@ -0,0 +1,133 @@ +#ifndef TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ +#define TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ +// Functor definition for BatchNormOp, must be compilable by nvcc. +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by BatchNormOp to do the computations. 
+template +struct BatchNorm { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstVec mean, + typename TTypes::ConstVec var, + typename TTypes::ConstVec beta, + typename TTypes::ConstVec gamma, float variance_epsilon, + bool scale_after_normalization, + typename TTypes::Tensor output) { + const int depth = mean.dimension(0); + const int rest_size = input.size() / depth; + + Eigen::DSizes rest_by_depth(rest_size, depth); +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::DSizes rest_by_one(rest_size, 1); + Eigen::DSizes one_by_depth(1, depth); + Eigen::DSizes depth_by_one(depth, 1); +#else + Eigen::IndexList > rest_by_one; + rest_by_one.set(0, rest_size); + Eigen::IndexList, int> one_by_depth; + one_by_depth.set(1, depth); + Eigen::IndexList > depth_by_one; + depth_by_one.set(0, depth); +#endif + if (scale_after_normalization) { + output.reshape(rest_by_depth).device(d) = + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one)) * + ((var + var.constant(variance_epsilon)).rsqrt() * gamma) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one) + + beta.reshape(one_by_depth).broadcast(rest_by_one); + } else { + output.reshape(rest_by_depth).device(d) = + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one)) * + ((var + var.constant(variance_epsilon)).rsqrt()) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one) + + beta.reshape(one_by_depth).broadcast(rest_by_one); + } + } +}; + +template +struct BatchNormGrad { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstVec mean, + typename TTypes::ConstVec var, + typename TTypes::ConstVec gamma, + typename TTypes::ConstTensor out_backprop, + float variance_epsilon, bool scale_after_normalization, + typename TTypes::Tensor dx, typename TTypes::Vec dm, + typename TTypes::Vec dv, typename TTypes::Vec db, + typename TTypes::Vec dg, typename TTypes::Vec scratch1, + typename TTypes::Vec scratch2) { + const int depth = mean.dimension(0); + const int rest_size = input.size() / depth; + + typedef typename TTypes::ConstVec::Index Index; + Eigen::DSizes rest_by_depth(rest_size, depth); + Eigen::DSizes rest_by_one(rest_size, 1); + Eigen::DSizes one_by_depth(1, depth); + + // db = out_backprop + // + // dg = out_backprop * ((x - m) * rsqrt(v + epsilon)) + // + // dv = sum_over_rest(out_backprop * gamma * (x - m)) * + // (-1/2) * (v + epsilon) ^ (-3/2) + // + // dm = sum_over_rest(out_backprop * gamma) * (-1 / rsqrt(v + epsilon)) + // + // dx = out_backprop * (gamma * rsqrt(v + epsilon)) + Eigen::array reduction_axis; + reduction_axis[0] = 0; // Reduces on first dimension. 
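+ // The steps below follow the formulas above: db is a plain reduction,
+ // scratch1 first holds rsqrt(var + epsilon), scratch2 holds
+ // sum_over_rest(out_backprop * (x - m)), and scratch1 is later repurposed
+ // as the -1/2 * (var + epsilon)^(-3/2) factor needed by dv.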
+ + db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis); + + // scratch1 = rsqrt(v + epsilon) + scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt(); + + // scratch2 = sum_over_rest(out_backprop * (x - m)) + scratch2.device(d) = (out_backprop.reshape(rest_by_depth) * + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one))) + .sum(reduction_axis); + + if (scale_after_normalization) { + dx.reshape(rest_by_depth).device(d) = + out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one)); + dm.device(d) = -db * (scratch1 * gamma).eval(); + dg.device(d) = scratch2 * scratch1; + } else { + dx.reshape(rest_by_depth).device(d) = + out_backprop.reshape(rest_by_depth) * + scratch1.reshape(one_by_depth).broadcast(rest_by_one); + dm.device(d) = -db * scratch1; + dg.device(d) = dg.constant(static_cast(0.0)); // Gamma is not learned. + } + + // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2) + scratch1.device(d) = scratch1 * scratch1.constant(static_cast(-0.5f)) / + (var + var.constant(variance_epsilon)); + + if (scale_after_normalization) { + dv.device(d) = scratch2 * (scratch1 * gamma).eval(); + } else { + dv.device(d) = scratch2 * scratch1; + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ diff --git a/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc b/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc new file mode 100644 index 0000000000..02e0eeecfa --- /dev/null +++ b/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc @@ -0,0 +1,17 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/batch_norm_op.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +template struct functor::BatchNorm; +template struct functor::BatchNormGrad; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc new file mode 100644 index 0000000000..bb1492e5b4 --- /dev/null +++ b/tensorflow/core/kernels/bcast_ops.cc @@ -0,0 +1,71 @@ +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +// Given shapes of two tensors, computes the reduction indices for the +// gradient computation. +// +// TODO(zhifengc): +// 1. Adds support for n-ary (n >= 2). +class BCastGradArgsOp : public OpKernel { + public: + explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK( + ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32})); + } + + void Compute(OpKernelContext* ctx) override { + OP_REQUIRES( + ctx, ctx->num_inputs() == 2, + errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); + gtl::InlinedVector shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& in = ctx->input(i); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()), + errors::InvalidArgument("In[", i, "] must be a vector.", + in.shape().ShortDebugString())); + BCast::Vec vec; + for (int64 i = 0; i < in.NumElements(); ++i) { + vec.push_back(in.vec()(i)); + } + shapes.push_back(vec); + } + BCast bcast(shapes[0], shapes[1]); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "Incompatible shapes: [", str_util::Join(shapes[0], ","), + "] vs. 
[", str_util::Join(shapes[1], ","), "]")); + Output(ctx, 0, bcast.grad_x_reduce_idx()); + Output(ctx, 1, bcast.grad_y_reduce_idx()); + } + + private: + void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) { + const int len = v.size(); + Tensor* o = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o)); + for (int i = 0; i < len; ++i) o->flat()(i) = v[i]; + } + + TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); +}; + +REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs") + .Device(DEVICE_CPU) + .HostMemory("s0") + .HostMemory("s1") + .HostMemory("r0") + .HostMemory("r1"), + BCastGradArgsOp); +REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs") + .Device(DEVICE_GPU) + .HostMemory("s0") + .HostMemory("s1") + .HostMemory("r0") + .HostMemory("r1"), + BCastGradArgsOp); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc new file mode 100644 index 0000000000..68737f6c2d --- /dev/null +++ b/tensorflow/core/kernels/bias_op.cc @@ -0,0 +1,112 @@ +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/bias_op.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class BiasOp : public BinaryOp { + public: + explicit BiasOp(OpKernelConstruction* context) : BinaryOp(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& bias = context->input(1); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()), + errors::InvalidArgument("Input tensor must be at least 2D: ", + input.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()), + errors::InvalidArgument("Biases must be 1D: ", + bias.shape().DebugString())); + const auto last_dim = input.shape().dims() - 1; + OP_REQUIRES( + context, bias.shape().dim_size(0) == input.shape().dim_size(last_dim), + errors::InvalidArgument( + "Must provide as many biases as the last dimension " + "of the input tensor: ", + bias.shape().DebugString(), " vs. ", input.shape().DebugString())); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + + switch (input.shape().dims()) { + case 2: + Compute<2>(context, input, bias, output); + break; + case 3: + Compute<3>(context, input, bias, output); + break; + case 4: + Compute<4>(context, input, bias, output); + break; + case 5: + Compute<5>(context, input, bias, output); + break; + default: + OP_REQUIRES(context, false, + errors::InvalidArgument("Only ranks up to 5 supported: ", + input.shape().DebugString())); + } + } + + // Add biases for an input matrix of rank Dims, by using the Bias. 
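+ // The switch in Compute() above maps the runtime rank (2..5) onto this
+ // compile-time Dims parameter, so each supported rank gets its own Eigen
+ // expression instantiation.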
+ template + void Compute(OpKernelContext* ctx, const Tensor& input, const Tensor& bias, + Tensor* output) { + functor::Bias functor; + functor(ctx->eigen_device(), input.tensor(), bias.vec(), + output->tensor()); + } +}; + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint("T"), \ + BiasOp); + +TF_CALL_NUMBER_TYPES(REGISTER_KERNEL); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T, Dims) \ + template <> \ + void Bias::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + typename TTypes::ConstVec bias, \ + typename TTypes::Tensor output); \ + extern template struct Bias; + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint("T"), \ + BiasOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/bias_op.h b/tensorflow/core/kernels/bias_op.h new file mode 100644 index 0000000000..513406d251 --- /dev/null +++ b/tensorflow/core/kernels/bias_op.h @@ -0,0 +1,41 @@ +#ifndef TENSORFLOW_KERNELS_BIAS_OP_H_ +#define TENSORFLOW_KERNELS_BIAS_OP_H_ +// Functor definition for BiasOp, must be compilable by nvcc. + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by BiasOp to do the computations. +template +struct Bias { + // Add "bias" to "input", broadcasting it on all dimensions but the last one. + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstVec bias, + typename TTypes::Tensor output) { + const int bias_size = bias.dimension(0); + const int rest_size = input.size() / bias_size; + + Eigen::DSizes rest_by_bias(rest_size, bias_size); +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::DSizes rest_by_one(rest_size, 1); + Eigen::DSizes one_by_bias(1, bias_size); +#else + Eigen::IndexList > rest_by_one; + rest_by_one.set(0, rest_size); + Eigen::IndexList, int> one_by_bias; + one_by_bias.set(1, bias_size); +#endif + + output.reshape(rest_by_bias).device(d) = + input.reshape(rest_by_bias) + + bias.reshape(one_by_bias).broadcast(rest_by_one); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_BIAS_OP_H_ diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc new file mode 100644 index 0000000000..d3377b3ce8 --- /dev/null +++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc @@ -0,0 +1,23 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/bias_op.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Definition of the GPU implementations declared in bias_op.cc. 
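+// One explicit instantiation per supported rank (2 through 5), mirroring the
+// DECLARE_GPU_SPEC forward declarations in bias_op.cc.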
+#define DEFINE_GPU_SPECS(T) \ + template struct functor::Bias; \ + template struct functor::Bias; \ + template struct functor::Bias; \ + template struct functor::Bias; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc new file mode 100644 index 0000000000..cd5fde37a6 --- /dev/null +++ b/tensorflow/core/kernels/candidate_sampler_ops.cc @@ -0,0 +1,243 @@ +// See docs in ../ops/candidate_sampling_ops.cc. + +#define EIGEN_USE_THREADS + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/range_sampler.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/util/guarded_philox_random.h" + +namespace tensorflow { + +class BaseCandidateSamplerOp : public OpKernel { + public: + explicit BaseCandidateSamplerOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("num_sampled", &num_sampled_)); + OP_REQUIRES_OK(context, context->GetAttr("num_true", &num_true_)); + OP_REQUIRES_OK(context, context->GetAttr("unique", &unique_)); + OP_REQUIRES_OK(context, generator_.Init(context)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& true_classes = context->input(0); + OP_REQUIRES(context, true_classes.dims() == 2, + errors::InvalidArgument("true_classes must be a matrix")); + const int32 batch_size = true_classes.dim_size(0); + OP_REQUIRES(context, true_classes.dim_size(1) == num_true_, + errors::InvalidArgument("true_classes must have " + "num_true columns")); + + // Output candidates and expected_count. + Tensor* out_sampled_candidates = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num_sampled_}), + &out_sampled_candidates)); + + Tensor* out_true_expected_count = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 1, TensorShape({batch_size, num_true_}), + &out_true_expected_count)); + Tensor* out_sampled_expected_count = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(2, TensorShape({num_sampled_}), + &out_sampled_expected_count)); + + gtl::ArraySlice true_candidate(true_classes.matrix().data(), + batch_size * num_true_); + gtl::MutableArraySlice sampled_candidate( + out_sampled_candidates->vec().data(), num_sampled_); + gtl::MutableArraySlice true_expected_count( + out_true_expected_count->matrix().data(), + batch_size * num_true_); + gtl::MutableArraySlice sampled_expected_count( + out_sampled_expected_count->vec().data(), num_sampled_); + + CHECK(sampler_) << "CandidateSamplerOp did not set sampler_"; + + // Approximately conservatively estimate the number of samples required. + // In cases where rejection sampling is used we may occasionally use more + // samples than expected, which will result in reused random bits. + const int64 samples32 = 2048 * num_sampled_; + + // Pick sampled candidates. 
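+ // ReserveSamples32() hands this call its own contiguous block of 32-bit
+ // Philox outputs, so separate Compute() invocations do not draw the same
+ // random bits (except in the rejection-sampling case noted above).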
+ auto local_gen = generator_.ReserveSamples32(samples32); + random::SimplePhilox random(&local_gen); + sampler_->SampleBatchGetExpectedCount(&random, unique_, &sampled_candidate, + &sampled_expected_count, + true_candidate, &true_expected_count); + + if (sampler_->NeedsUpdates()) { + sampler_->Update(true_candidate); + } + } + + protected: + void set_sampler(RangeSampler* sampler) { sampler_.reset(sampler); } + + private: + int32 num_true_; + int32 num_sampled_; + bool unique_; + std::unique_ptr sampler_; + GuardedPhiloxRandom generator_; +}; + +template +class SimpleCandidateSamplerOp : public BaseCandidateSamplerOp { + public: + explicit SimpleCandidateSamplerOp(OpKernelConstruction* context) + : BaseCandidateSamplerOp(context) { + int64 range_max; + OP_REQUIRES_OK(context, context->GetAttr("range_max", &range_max)); + set_sampler(new RangeSamplerType(range_max)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("UniformCandidateSampler").Device(DEVICE_CPU), + SimpleCandidateSamplerOp); + +REGISTER_KERNEL_BUILDER(Name("LogUniformCandidateSampler").Device(DEVICE_CPU), + SimpleCandidateSamplerOp); + +REGISTER_KERNEL_BUILDER(Name("LearnedUnigramCandidateSampler") + .Device(DEVICE_CPU), + SimpleCandidateSamplerOp); + +REGISTER_KERNEL_BUILDER(Name("ThreadUnsafeUnigramCandidateSampler") + .Device(DEVICE_CPU), + SimpleCandidateSamplerOp); + +class AllCandidateSamplerOp : public BaseCandidateSamplerOp { + public: + explicit AllCandidateSamplerOp(OpKernelConstruction* context) + : BaseCandidateSamplerOp(context) { + int64 range_max; + OP_REQUIRES_OK(context, context->GetAttr("num_sampled", &range_max)); + set_sampler(new AllSampler(range_max)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("AllCandidateSampler").Device(DEVICE_CPU), + AllCandidateSamplerOp); + +class FixedUnigramCandidateSamplerOp : public BaseCandidateSamplerOp { + public: + explicit FixedUnigramCandidateSamplerOp(OpKernelConstruction* context) + : BaseCandidateSamplerOp(context) { + int64 range_max; + OP_REQUIRES_OK(context, context->GetAttr("range_max", &range_max)); + string vocab_file; + OP_REQUIRES_OK(context, context->GetAttr("vocab_file", &vocab_file)); + std::vector unigrams; + OP_REQUIRES_OK(context, context->GetAttr("unigrams", &unigrams)); + OP_REQUIRES( + context, !vocab_file.empty() || !unigrams.empty(), + errors::InvalidArgument("Must provide either vocab_file or unigrams.")); + OP_REQUIRES(context, vocab_file.empty() || unigrams.empty(), + errors::InvalidArgument( + "Must only provide one of vocab_file and unigrams.")); + float distortion; + OP_REQUIRES_OK(context, context->GetAttr("distortion", &distortion)); + int64 num_reserved_ids; + OP_REQUIRES_OK(context, + context->GetAttr("num_reserved_ids", &num_reserved_ids)); + int64 num_shards; + OP_REQUIRES_OK(context, context->GetAttr("num_shards", &num_shards)); + int64 shard; + OP_REQUIRES_OK(context, context->GetAttr("shard", &shard)); + + if (!vocab_file.empty()) { + set_sampler(new FixedUnigramSampler(context->env(), range_max, vocab_file, + distortion, num_reserved_ids, + num_shards, shard)); + } else { + set_sampler(new FixedUnigramSampler(range_max, unigrams, distortion, + num_reserved_ids, num_shards, shard)); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("FixedUnigramCandidateSampler").Device(DEVICE_CPU), + FixedUnigramCandidateSamplerOp); + +class ComputeAccidentalHitsOp : public OpKernel { + public: + explicit ComputeAccidentalHitsOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("num_true", &num_true_)); + } + + void 
Compute(OpKernelContext* context) override { + const Tensor& in_true_candidates = context->input(0); + TensorShape in_true_candidates_shape = in_true_candidates.shape(); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(in_true_candidates_shape) && + in_true_candidates_shape.dim_size(1) == num_true_, + errors::InvalidArgument( + "true_candidates must be a batch_size * num_true matrix")); + + const int64 batch_size = in_true_candidates_shape.dim_size(0); + + const Tensor& in_sampled_candidates = context->input(1); + OP_REQUIRES(context, + TensorShapeUtils::IsVector(in_sampled_candidates.shape()), + errors::InvalidArgument( + "sampled_candidates must be a vector, which is typically " + "an output from CandidateSampler")); + + std::unordered_map sampled_candidate_to_pos; + for (int64 i = 0; i < in_sampled_candidates.dim_size(0); ++i) { + sampled_candidate_to_pos[in_sampled_candidates.vec()(i)] = i; + } + + // Produce output in the same format as UnpackSparseFeatures. + std::vector indices; + std::vector ids; + std::vector weights; + + for (int64 i = 0; i < batch_size; ++i) { + for (int64 j = 0; j < num_true_; ++j) { + const int64 true_candidate = in_true_candidates.matrix()(i, j); + const auto look = sampled_candidate_to_pos.find(true_candidate); + if (look != sampled_candidate_to_pos.end()) { + indices.push_back(i); + ids.push_back(look->second); + weights.push_back(-FLT_MAX); + } + } + } + + Tensor* out_indices = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output( + 0, TensorShape({static_cast(indices.size())}), &out_indices)); + Tensor* out_ids = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output( + 1, TensorShape({static_cast(ids.size())}), &out_ids)); + Tensor* out_weights = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output( + 2, TensorShape({static_cast(weights.size())}), &out_weights)); + + for (size_t i = 0; i < indices.size(); ++i) { + out_indices->vec()(i) = indices[i]; + out_ids->vec()(i) = ids[i]; + out_weights->vec()(i) = weights[i]; + } + } + + private: + int64 num_true_; +}; + +REGISTER_KERNEL_BUILDER(Name("ComputeAccidentalHits").Device(DEVICE_CPU), + ComputeAccidentalHitsOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc new file mode 100644 index 0000000000..779ac57b6a --- /dev/null +++ b/tensorflow/core/kernels/cast_op.cc @@ -0,0 +1,233 @@ +// See docs in ../ops/math_ops.cc. 
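+//
+// CastOpBase resolves the (SrcT, DstT) pair once, at construction time, into
+// a work_ closure built from the CAST_CASE table below; Compute() then either
+// forwards the input unchanged (identity cast) or runs that closure. For
+// example, a float-to-int32 Cast on CPU is handled by the float/int32
+// CAST_CASE entry.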
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/cast_op.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template +void CastMaybeInline(const Device& d, typename TTypes::Flat o, + typename TTypes::ConstFlat i) { + if (o.size() * (sizeof(Tin) + sizeof(Tout)) < 131072) { + // Small cast on a CPU: do inline + o = i.template cast(); + } else { + o.device(d) = i.template cast(); + } +} + +template +struct CastFunctor { + void operator()(const CPUDevice& d, typename TTypes::Flat o, + typename TTypes::ConstFlat i) { + CastMaybeInline(d, o, i); + } +}; + +} // namespace functor + +#define CAST_CASE(DEVICE, IN, OUT) \ + if (DataTypeToEnum::value == src_dtype_ && \ + DataTypeToEnum::value == dst_dtype_) { \ + work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \ + functor::CastFunctor func; \ + func(ctx->eigen_device(), out->flat(), inp.flat()); \ + }; \ + return Status::OK(); \ + } + +class CastOpBase : public OpKernel { + public: + explicit CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& inp = ctx->input(0); + if (work_ == nullptr) { + ctx->set_output(0, inp); + } else { + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); + work_(ctx, inp, out); + } + } + + protected: + DataType src_dtype_; + DataType dst_dtype_; + std::function work_ = nullptr; + + virtual Status Prepare() = 0; + Status Unimplemented() { + return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ", + DataTypeString(dst_dtype_), + " is not supported"); + } + + TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); +}; + +class CpuCastOp : public CastOpBase { + public: + explicit CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { + OP_REQUIRES_OK(ctx, Prepare()); + } + + protected: + Status Prepare() override { + if (src_dtype_ == dst_dtype_) { + work_ = nullptr; // Identity + return Status::OK(); + } + CAST_CASE(CPUDevice, bool, float); + CAST_CASE(CPUDevice, bool, int32); + CAST_CASE(CPUDevice, bool, double); + CAST_CASE(CPUDevice, double, float); + CAST_CASE(CPUDevice, double, int32); + CAST_CASE(CPUDevice, double, int64); + CAST_CASE(CPUDevice, float, double); + CAST_CASE(CPUDevice, float, uint8); + CAST_CASE(CPUDevice, float, int32); + CAST_CASE(CPUDevice, float, int64); + CAST_CASE(CPUDevice, int32, double); + CAST_CASE(CPUDevice, int32, float); + CAST_CASE(CPUDevice, int32, uint8); + CAST_CASE(CPUDevice, int32, int64); + CAST_CASE(CPUDevice, int64, double); + CAST_CASE(CPUDevice, int64, float); + CAST_CASE(CPUDevice, int64, int32); + CAST_CASE(CPUDevice, uint8, float); + CAST_CASE(CPUDevice, uint8, int32); + CAST_CASE(CPUDevice, uint8, int64); + CAST_CASE(CPUDevice, uint8, double); + if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) { + work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { + int64 N = out->NumElements(); + auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + int num_threads = + 
std::min(std::min(4, worker_threads->num_threads), N / 4096); + if (num_threads < 1) { + BFloat16ToFloat(inp.flat().data(), + out->flat().data(), N); + } else { + auto work = [&inp, &out](int64 start, int64 end) { + BFloat16ToFloat(inp.flat().data() + start, + out->flat().data() + start, end - start); + }; + Shard(num_threads, worker_threads->workers, N, 100, work); + } + }; + return Status::OK(); + } + if (src_dtype_ == DT_FLOAT && dst_dtype_ == DT_BFLOAT16) { + work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { + int64 N = out->NumElements(); + auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + int num_threads = + std::min(std::min(4, worker_threads->num_threads), N / 4096); + if (num_threads < 1) { + FloatToBFloat16(inp.flat().data(), + out->flat().data(), N); + } else { + auto work = [&inp, &out](int64 start, int64 end) { + FloatToBFloat16(inp.flat().data() + start, + out->flat().data() + start, end - start); + }; + Shard(num_threads, worker_threads->workers, N, 100, work); + } + }; + return Status::OK(); + } + return Unimplemented(); + } +}; + +class GpuCastOp : public CastOpBase { + public: + explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { + OP_REQUIRES_OK(ctx, Prepare()); + } + + protected: + Status Prepare() override { + if (src_dtype_ == dst_dtype_) { + work_ = nullptr; // Identity + return Status::OK(); + } + CAST_CASE(GPUDevice, bfloat16, float); + CAST_CASE(GPUDevice, bool, float); + CAST_CASE(GPUDevice, double, float); + CAST_CASE(GPUDevice, double, int64); + CAST_CASE(GPUDevice, float, bfloat16); + CAST_CASE(GPUDevice, float, double); + CAST_CASE(GPUDevice, float, int64); + CAST_CASE(GPUDevice, int64, double); + CAST_CASE(GPUDevice, int64, float); + CAST_CASE(GPUDevice, uint8, float); + CAST_CASE(GPUDevice, float, uint8); + CAST_CASE(GPUDevice, bool, int32); + CAST_CASE(GPUDevice, double, int32); + CAST_CASE(GPUDevice, float, int32); + CAST_CASE(GPUDevice, int32, double); + CAST_CASE(GPUDevice, int32, float); + CAST_CASE(GPUDevice, int32, int64); + CAST_CASE(GPUDevice, int64, int32); + return Unimplemented(); + } +}; + +#undef CAST_CASE + +REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp); + +#if GOOGLE_CUDA +#define REGISTER_CAST_GPU(srctype, dsttype) \ + REGISTER_KERNEL_BUILDER(Name("Cast") \ + .TypeConstraint("SrcT") \ + .TypeConstraint("DstT") \ + .Device(DEVICE_GPU), \ + GpuCastOp); +REGISTER_CAST_GPU(bfloat16, float); +REGISTER_CAST_GPU(bool, float); +REGISTER_CAST_GPU(double, float); +REGISTER_CAST_GPU(double, int64); +REGISTER_CAST_GPU(float, bfloat16); +REGISTER_CAST_GPU(float, double); +REGISTER_CAST_GPU(float, int64); +REGISTER_CAST_GPU(int64, double); +REGISTER_CAST_GPU(int64, float); +REGISTER_CAST_GPU(uint8, float); +REGISTER_CAST_GPU(float, uint8); +REGISTER_CAST_GPU(bool, int32); +REGISTER_CAST_GPU(double, int32); +REGISTER_CAST_GPU(float, int32); +REGISTER_CAST_GPU(int32, double); +REGISTER_CAST_GPU(int32, float); +REGISTER_CAST_GPU(int32, int64); +REGISTER_CAST_GPU(int64, int32); +#undef REGISTER_CAST_GPU +#endif // GOOGLE_CUDA + +// HostCast differs from Cast in that its input and output are in host memory. 
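+// Even when placed on a GPU device, the registration below pins both "x" and
+// "y" to host memory, so the same CpuCastOp implementation serves _HostCast
+// on both device types.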
+REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp); +REGISTER_KERNEL_BUILDER( + Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"), + CpuCastOp); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h new file mode 100644 index 0000000000..d066206abc --- /dev/null +++ b/tensorflow/core/kernels/cast_op.h @@ -0,0 +1,71 @@ +#ifndef TENSORFLOW_KERNELS_CAST_OP_H_ +#define TENSORFLOW_KERNELS_CAST_OP_H_ + +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/port.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +template +void Cast(const Device& d, typename TTypes::Flat o, + typename TTypes::ConstFlat i) { + o.device(d) = i.template cast(); +} + +template +struct CastFunctor { + void operator()(const Device& d, typename TTypes::Flat o, + typename TTypes::ConstFlat i); +}; + +} // end namespace functor +} // end namespace tensorflow + +namespace Eigen { +namespace internal { + +// Specialized cast op impls for bfloat16. +template <> +struct scalar_cast_op< ::tensorflow::bfloat16, float> { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef float result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()( + const ::tensorflow::bfloat16& a) const { + static_assert(::tensorflow::port::kLittleEndian, ""); + float ret; + uint16_t* p = reinterpret_cast(&ret); + p[0] = 0; + p[1] = a.value; + return ret; + } +}; + +template <> +struct functor_traits > { + enum { Cost = NumTraits::AddCost, PacketAccess = false }; +}; + +template <> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef ::tensorflow::bfloat16 result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()( + const float a) const { + static_assert(::tensorflow::port::kLittleEndian, ""); + const uint16_t* p = reinterpret_cast(&a); + return ::tensorflow::bfloat16(p[1]); + } +}; + +template <> +struct functor_traits > { + enum { Cost = NumTraits::AddCost, PacketAccess = false }; +}; + +} // namespace internal +} // namespace Eigen + +#endif // TENSORFLOW_KERNELS_CAST_OP_H_ diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc new file mode 100644 index 0000000000..cd198c752b --- /dev/null +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -0,0 +1,45 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/kernels/cast_op.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +template +struct CastFunctor { + void operator()(const GPUDevice& d, typename TTypes::Flat o, + typename TTypes::ConstFlat i) { + Cast(d, o, i); + } +}; + +#define DEFINE(O, I) template struct CastFunctor; +DEFINE(float, double); +DEFINE(float, int32); +DEFINE(float, int64); +DEFINE(double, float); +DEFINE(double, int32); +DEFINE(double, int64); +DEFINE(int32, float); +DEFINE(int32, double); +DEFINE(int32, int64); +DEFINE(int64, float); +DEFINE(int64, double); +DEFINE(int64, int32); +DEFINE(int32, bool); +DEFINE(float, bool); +DEFINE(float, uint8); +DEFINE(uint8, float); +DEFINE(float, bfloat16); +DEFINE(bfloat16, float); +#undef DEFINE + +} // end namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc new 
file mode 100644 index 0000000000..f774fbcfe8 --- /dev/null +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -0,0 +1,100 @@ +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" +#include + +namespace tensorflow { + +template +static Graph* Cast(int num) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor data(DataTypeToEnum::value, + TensorShape({64, 64, num / (64 * 64)})); + data.flat().setRandom(); + test::graph::Cast(g, test::graph::Constant(g, data), + DataTypeToEnum::value); + return g; +} + +class CastOpTest : public OpsTestBase { + protected: + void MakeOp(DataType src, DataType dst) { + RequireDefaultOps(); + EXPECT_OK(NodeDefBuilder("cast_op", "Cast") + .Input(FakeInput(DT_INT32)) + .Attr("SrcT", src) + .Attr("DstT", dst) + .Finalize(node_def())); + EXPECT_OK(InitOp()); + } +}; + +TEST_F(CastOpTest, Int32ToUint8) { + MakeOp(DT_INT32, DT_UINT8); + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_UINT8, TensorShape({1, 2, 2, 1})); + test::FillValues(&expected, {1, 2, 3, 4}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +static void BM_cpu_float_int64(int iters, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + testing::BytesProcessed(static_cast(iters) * num * + (sizeof(float) + sizeof(int64))); + testing::UseRealTime(); + test::Benchmark("cpu", Cast(num)).Run(iters); +} +BENCHMARK(BM_cpu_float_int64)->Arg(64 << 10)->Arg(32 << 20); + +static void BM_gpu_float_int64(int iters, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + testing::BytesProcessed(static_cast(iters) * num * + (sizeof(float) + sizeof(int64))); + testing::UseRealTime(); + test::Benchmark("gpu", Cast(num)).Run(iters); +} +BENCHMARK(BM_gpu_float_int64)->Arg(64 << 10)->Arg(32 << 20); + +static void BM_cpu_bool_float(int iters, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + testing::BytesProcessed(static_cast(iters) * num * + (sizeof(bool) + sizeof(float))); + testing::UseRealTime(); + test::Benchmark("cpu", Cast(num)).Run(iters); +} +BENCHMARK(BM_cpu_bool_float)->Arg(64 << 10)->Arg(32 << 20); + +static void BM_gpu_bool_float(int iters, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + testing::BytesProcessed(static_cast(iters) * num * + (sizeof(bool) + sizeof(float))); + testing::UseRealTime(); + test::Benchmark("gpu", Cast(num)).Run(iters); +} +BENCHMARK(BM_gpu_bool_float)->Arg(64 << 10)->Arg(32 << 20); + +static void BM_cpu_float_bfloat16(int iters, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + testing::BytesProcessed(static_cast(iters) * num * + (sizeof(float) + sizeof(bfloat16))); + testing::UseRealTime(); + test::Benchmark("cpu", Cast(num)).Run(iters); +} +BENCHMARK(BM_cpu_float_bfloat16)->Arg(64 << 10)->Arg(32 << 20); + +static void BM_cpu_bfloat16_float(int iters, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + testing::BytesProcessed(static_cast(iters) * num * + (sizeof(float) + sizeof(bfloat16))); + testing::UseRealTime(); + test::Benchmark("cpu", Cast(num)).Run(iters); +} +BENCHMARK(BM_cpu_bfloat16_float)->Arg(64 << 10)->Arg(32 << 20); + +} // end 
namespace tensorflow diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc new file mode 100644 index 0000000000..65487a303c --- /dev/null +++ b/tensorflow/core/kernels/check_numerics_op.cc @@ -0,0 +1,190 @@ +// See docs in ../ops/array_ops.cc. + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/stream_executor/stream.h" +#endif // GOOGLE_CUDA +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#if GOOGLE_CUDA +template +struct CheckNumericsLaunch { + void Run(const GPUDevice& d, const T* data, int size, + int abnormal_detected[2]); +}; +#endif + +namespace { + +template +class CheckNumericsOp; + +// Partial specialization for CPU +template +class CheckNumericsOp : public OpKernel { + public: + explicit CheckNumericsOp(OpKernelConstruction* context) : OpKernel(context) { + // message_ is used as the prefix for the assertion error message. For + // instance, this can be the name of the input op that produced the tensor. + OP_REQUIRES_OK(context, context->GetAttr("message", &message_)); + } + + void Compute(OpKernelContext* context) override { + // pass along the input to the output + context->set_output(0, context->input(0)); + + auto in = context->input(0).flat(); + const T* data = in.data(); + const int size = in.size(); + // Check to see if any element of the tensor is NaN or Inf. + int fp_props = + std::accumulate(data, data + size, 0, [](const int& x, const T& y) { + int prop = std::fpclassify(y); + int result = x; + if (prop == FP_INFINITE) { + result |= kInfBit; + } else if (prop == FP_NAN) { + result |= kNaNBit; + } + return result; + }); + string status; + if ((fp_props & kInfBit) && (fp_props & kNaNBit)) { + status = "Inf and NaN"; + } else { + if (fp_props & kInfBit) { + status = "Inf"; + } + if (fp_props & kNaNBit) { + status = "NaN"; + } + } + if (!status.empty()) { + context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ", + status, " values")); + } + } + + private: + string message_; + static const int kInfBit = 0x01; + static const int kNaNBit = 0x02; +}; + +#if GOOGLE_CUDA +// Partial specialization for GPU +template +class CheckNumericsOp : public OpKernel { + public: + typedef GPUDevice Device; + + explicit CheckNumericsOp(OpKernelConstruction* context) : OpKernel(context) { + // message_ is used as the prefix for the assertion error message. For + // instance, this can be the name of the input op that produced the tensor. 
+ OP_REQUIRES_OK(context, context->GetAttr("message", &message_)); + } + + void Compute(OpKernelContext* context) override { + // pass along the input to the output + context->set_output(0, context->input(0)); + auto input = context->input(0).flat(); + + // Allocate and initialize the elements to hold the check results + const int abnormal_detected_size = 2; + Tensor abnormal_detected; + OP_REQUIRES_OK(context, context->allocate_temp( + DT_INT32, TensorShape({abnormal_detected_size}), + &abnormal_detected)); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + perftools::gputools::DeviceMemoryBase abnormal_detected_ptr( + abnormal_detected.flat().data(), + abnormal_detected.flat().size()); + stream->ThenMemset32(&abnormal_detected_ptr, 0, + abnormal_detected.flat().size() * sizeof(int)); + + // Call the Cuda kernels for the numerical checks + const Device& d = context->eigen_device(); + CheckNumericsLaunch().Run(d, input.data(), input.size(), + abnormal_detected.flat().data()); + + // Copy the results from device to host + AllocatorAttributes attr; + attr.set_on_host(true); + attr.set_gpu_compatible(true); + Tensor abnormal_detected_out; + OP_REQUIRES_OK(context, context->allocate_temp( + DT_INT32, TensorShape({abnormal_detected_size}), + &abnormal_detected_out, attr)); + int* abnormal_detected_host = abnormal_detected_out.flat().data(); + stream->ThenMemcpy(abnormal_detected_host, abnormal_detected_ptr, + abnormal_detected_size * sizeof(int)); + stream->BlockHostUntilDone(); + OP_REQUIRES(context, stream->ok(), + errors::Internal("cudaMemcpy from device to host failed")); + + int is_nan = abnormal_detected_host[0]; + int is_inf = abnormal_detected_host[1]; + if (is_nan || is_inf) { + string status; + LOG(ERROR) << "abnormal_detected_host @" << abnormal_detected_host + << " = {" << is_nan << ", " << is_inf << "} " << message_; + + // Results should always be 1 or 0. If we see anything else then + // there has been some GPU memory corruption. 
+ CHECK_GE(is_nan, 0); + CHECK_GE(is_inf, 0); + CHECK_LE(is_nan, 1); + CHECK_LE(is_inf, 1); + + if (is_nan && is_inf) { + status = "Inf and NaN"; + } else if (is_nan) { + status = "NaN"; + } else if (is_inf) { + status = "Inf"; + } + context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ", + status, " values")); + } + } + + private: + string message_; +}; +#endif // GOOGLE_CUDA + +} // namespace + +REGISTER_KERNEL_BUILDER(Name("CheckNumerics") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + CheckNumericsOp); +REGISTER_KERNEL_BUILDER(Name("CheckNumerics") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + CheckNumericsOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("CheckNumerics") + .Device(DEVICE_GPU) + .TypeConstraint("T"), + CheckNumericsOp); +REGISTER_KERNEL_BUILDER(Name("CheckNumerics") + .Device(DEVICE_GPU) + .TypeConstraint("T"), + CheckNumericsOp); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc new file mode 100644 index 0000000000..cb84f98731 --- /dev/null +++ b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc @@ -0,0 +1,62 @@ +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include +#include + +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace { + +typedef Eigen::GpuDevice GPUDevice; + +// A Cuda kernel to check if each element is Inf or Nan. If any exists, the +// relevant elements in abnormal_detected will be set +template +__global__ void CheckNumericsKernel(const T *data, int size, + int abnormal_detected[2]) { + const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int32 total_thread_count = gridDim.x * blockDim.x; + + int32 offset = thread_id; + + while (offset < size) { + if (isnan(data[offset])) { + abnormal_detected[0] = 1; + } + if (isinf(data[offset])) { + abnormal_detected[1] = 1; + } + offset += total_thread_count; + } +} + +} // namespace + +// A simple launch pad to launch the Cuda kernels that checks the numerical +// abnormality in the given array +template +struct CheckNumericsLaunch { + void Run(const GPUDevice &d, const T *data, int size, + int abnormal_detected[2]) { + const int32 block_size = d.maxCudaThreadsPerBlock(); + const int32 num_blocks = + (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) / + block_size; + + CheckNumericsKernel<<>>( + data, size, abnormal_detected); + } +}; + +template struct CheckNumericsLaunch; +template struct CheckNumericsLaunch; + +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc new file mode 100644 index 0000000000..12632fb248 --- /dev/null +++ b/tensorflow/core/kernels/cholesky_op.cc @@ -0,0 +1,71 @@ +// See docs in ../ops/linalg_ops.cc. +// TODO(konstantinos): Enable complex inputs. This will require additional tests +// and OP_REQUIRES. 
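+//
+// CholeskyOp computes the lower-triangular factor L of a self-adjoint
+// positive-definite input A (A = L * L^T) via Eigen's LLT decomposition;
+// only the lower triangle of the input is read.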
+ +#include + +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/Eigen/Cholesky" + +namespace tensorflow { + +template +class CholeskyOp : public LinearAlgebraOp { + public: + explicit CholeskyOp(OpKernelConstruction* context) + : LinearAlgebraOp(context) {} + + TensorShape GetOutputMatrixShape( + const TensorShape& input_matrix_shape) override { + return input_matrix_shape; + } + + int64 GetCostPerUnit(const TensorShape& input_matrix_shape) override { + const int64 rows = input_matrix_shape.dim_size(0); + if (rows > (1LL << 20)) { + // A big number to cap the cost in case overflow. + return kint32max; + } else { + return rows * rows * rows; + } + } + + using typename LinearAlgebraOp::MatrixMap; + using + typename LinearAlgebraOp::ConstMatrixMap; + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input, + MatrixMap* output) override { + OP_REQUIRES(context, input.rows() == input.cols(), + errors::InvalidArgument("Input matrix must be square.")); + if (input.rows() == 0) { + // If X is an empty matrix (0 rows, 0 col), X * X' == X. + // Therefore, we return X. + return; + } + // Perform the actual LL^T Cholesky decomposition. This will only use + // the lower triangular part of data_in by default. The upper triangular + // part of the matrix will not be read. + Eigen::LLT> llt_decomposition(input); + + // Output the lower triangular in a dense form. + *output = llt_decomposition.matrixL(); + + OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success, + errors::InvalidArgument("LLT decomposition was not successful. " + "The input might not be valid.")); + } +}; + +REGISTER_LINALG_OP("Cholesky", (CholeskyOp), float); +REGISTER_LINALG_OP("Cholesky", (CholeskyOp), double); +REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp), float); +REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp), double); +} // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc new file mode 100644 index 0000000000..b68fcec515 --- /dev/null +++ b/tensorflow/core/kernels/concat_op.cc @@ -0,0 +1,153 @@ +// See docs in ../ops/array_ops.cc. 
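+//
+// ConcatOp reduces an N-dimensional concat to a 2-D one: each input is viewed
+// as an {x, y} matrix, where x is the product of the dimensions before
+// concat_dim and y the product of the remaining ones. As an illustrative
+// example, concatenating two [2, 3, 4] tensors along dimension 1 views each
+// input as a [2, 12] matrix and the [2, 6, 4] result as [2, 24].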
+ +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/concat_op.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// -------------------------------------------------------------------------- +template +class ConcatOp : public OpKernel { + public: + typedef std::vector::ConstMatrix>> + ConstMatrixVector; + + explicit ConcatOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + const Tensor* concat_dim_tensor; + OP_REQUIRES_OK(c, c->input("concat_dim", &concat_dim_tensor)); + OP_REQUIRES( + c, TensorShapeUtils::IsLegacyScalar(concat_dim_tensor->shape()), + errors::InvalidArgument( + "Concat dim tensor should be a scalar integer, but got shape ", + concat_dim_tensor->shape().DebugString())); + const int32 concat_dim = concat_dim_tensor->scalar()(); + OpInputList values; + OP_REQUIRES_OK(c, c->input_list("values", &values)); + const int N = values.size(); + const int input_dims = values[0].dims(); + const TensorShape& input_shape = values[0].shape(); + OP_REQUIRES( + c, (0 <= concat_dim && concat_dim < input_dims) || + (kAllowLegacyScalars && concat_dim == 0), + errors::InvalidArgument( + "ConcatOp : Expected concatenating dimensions in the range [", 0, + ", ", input_dims, "), but got ", concat_dim)); + + // Note that we reduce the concat of n-dimensional tensors into a two + // dimensional concat. Assuming the dimensions of any input/output + // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along + // the dimension indicated with size y0, we flatten it to {x, y}, where y = + // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1). + ConstMatrixVector inputs_flat; + inputs_flat.reserve(N); + int64 inputs_flat_dim0 = 1; + for (int d = 0; d < concat_dim; ++d) { + inputs_flat_dim0 *= input_shape.dim_size(d); + } + int output_concat_dim = 0; + const bool input_is_scalar = TensorShapeUtils::IsLegacyScalar(input_shape); + for (int i = 0; i < N; ++i) { + const auto in = values[i]; + const bool in_is_scalar = TensorShapeUtils::IsLegacyScalar(in.shape()); + OP_REQUIRES( + c, in.dims() == input_dims || (input_is_scalar && in_is_scalar), + errors::InvalidArgument( + "ConcatOp : Ranks of all input tensors should match: shape[0] = ", + input_shape.ShortDebugString(), " vs. shape[", i, "] = ", + in.shape().ShortDebugString())); + for (int j = 0; j < input_dims; ++j) { + if (j == concat_dim) { + continue; + } + OP_REQUIRES( + c, in.dim_size(j) == input_shape.dim_size(j), + errors::InvalidArgument( + "ConcatOp : Dimensions of inputs should match: shape[0] = ", + input_shape.ShortDebugString(), " vs. shape[", i, "] = ", + in.shape().ShortDebugString())); + } + if (in.NumElements() > 0) { + int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + in.shaped({inputs_flat_dim0, inputs_flat_dim1}))); + } + // TODO(irving): Remove check once !kAllowLegacyScalars + output_concat_dim += in.dims() > 0 ? 
in.dim_size(concat_dim) : 1; + } + + TensorShape output_shape(input_shape); + // TODO(irving): Remove rank 0 case once !kAllowLegacyScalars + if (output_shape.dims() == 0) { + output_shape.AddDim(output_concat_dim); + } else { + output_shape.set_dim(concat_dim, output_concat_dim); + } + Tensor* output = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); + if (output->NumElements() > 0) { + int64 output_dim1 = output->NumElements() / inputs_flat_dim0; + auto output_flat = output->shaped({inputs_flat_dim0, output_dim1}); + if (std::is_same::value) { + ConcatGPU(c->eigen_gpu_device(), inputs_flat, &output_flat); + } else { + ConcatCPU(c->device(), inputs_flat, &output_flat); + } + } + } +}; + +#define REGISTER_CONCAT(type) \ + REGISTER_KERNEL_BUILDER(Name("Concat") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("concat_dim"), \ + ConcatOp) + +TF_CALL_ALL_TYPES(REGISTER_CONCAT); +REGISTER_CONCAT(quint8); +REGISTER_CONCAT(qint8); +REGISTER_CONCAT(qint32); +REGISTER_CONCAT(bfloat16); + +#undef REGISTER_CONCAT + +#if GOOGLE_CUDA + +#define REGISTER_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("Concat") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("concat_dim"), \ + ConcatOp) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +#undef REGISTER_GPU + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Concat") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("concat_dim") + .HostMemory("values") + .HostMemory("output"), + ConcatOp); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_op.h b/tensorflow/core/kernels/concat_op.h new file mode 100644 index 0000000000..664e55080d --- /dev/null +++ b/tensorflow/core/kernels/concat_op.h @@ -0,0 +1,27 @@ +#ifndef TENSORFLOW_KERNELS_CONCAT_OP_H_ +#define TENSORFLOW_KERNELS_CONCAT_OP_H_ + +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/device_base.h" + +namespace tensorflow { + +// Assumes all inputs are nonempty +template +void ConcatCPU(DeviceBase* d, + const std::vector< + std::unique_ptr::ConstMatrix>>& inputs, + typename TTypes::Matrix* output); + +// Assumes all inputs are nonempty +template +void ConcatGPU(const Eigen::GpuDevice& d, + const std::vector< + std::unique_ptr::ConstMatrix>>& inputs, + typename TTypes::Matrix* output); + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CONCAT_OP_H_ diff --git a/tensorflow/core/kernels/concat_op_cpu.cc b/tensorflow/core/kernels/concat_op_cpu.cc new file mode 100644 index 0000000000..679a53721c --- /dev/null +++ b/tensorflow/core/kernels/concat_op_cpu.cc @@ -0,0 +1,122 @@ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/concat_op.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +template +static inline void Copy(T* dst, const T* src, int n) { + if (DataTypeCanUseMemcpy(DataTypeToEnum::v())) { + memcpy(dst, src, n * sizeof(T)); + } else { + for (int k = 0; k < n; ++k) { + *dst++ = *src++; + } + } +} + +template +void ConcatCPU(DeviceBase* d, + const std::vector< + std::unique_ptr::ConstMatrix>>& inputs, + typename TTypes::Matrix* output) { + int num_inputs = inputs.size(); + std::vector sizes; + sizes.reserve(num_inputs); + int row_size = 0; + for (int j = 0; j < num_inputs; 
++j) { + sizes.push_back(inputs[j]->dimension(1)); + row_size += sizes.back(); + } + + auto worker_threads = d->tensorflow_cpu_worker_threads(); + int num_threads = std::min(std::min(4, worker_threads->num_threads), + output->size() / 4096); + // Single threaded mode. + if (num_threads == 0) { + T* out = &(*output)(0, 0); + std::vector inp; + inp.reserve(num_inputs); + for (int j = 0; j < num_inputs; ++j) { + inp.push_back(&(*inputs[j])(0, 0)); + } + const int dim0 = output->dimension(0); + for (int i = 0; i < dim0; ++i) { + for (int j = 0; j < num_inputs; ++j) { + auto size = sizes[j]; + Copy(out, inp[j], size); + out += size; + inp[j] += size; + } + } + return; + } + + // Sharded mode. + auto work = [&row_size, &sizes, &inputs, &output, &num_inputs](int64 start, + int64 end) { + int64 skipped_rows = start / row_size; + T* out = output->data() + skipped_rows * row_size; + T* out_start = output->data() + start; + T* out_end = output->data() + end; + + // Handle partial row at start + if (out < out_start) { + for (int j = 0; j < num_inputs; ++j) { + ptrdiff_t size = sizes[j]; + ptrdiff_t offset = out_start - out; + if (size <= offset) { + out += size; + continue; + } + const T* inp = &(*inputs[j])(skipped_rows, 0); + if (offset > 0) { + out += offset; + inp += offset; + size -= offset; + } + size = std::min(size, out_end - out); + if (size <= 0) break; + Copy(out, inp, size); + out += size; + } + ++skipped_rows; + } + if (out == out_end) return; + CHECK(out >= out_start); + CHECK(out < out_end); + + // Copy remaining data. + std::vector inp; + inp.reserve(num_inputs); + for (int j = 0; j < num_inputs; ++j) { + inp.push_back(&(*inputs[j])(skipped_rows, 0)); + } + const int dim0 = output->dimension(0); + for (int i = skipped_rows; i < dim0; ++i) { + for (int j = 0; j < num_inputs; ++j) { + ptrdiff_t size = std::min(sizes[j], out_end - out); + Copy(out, inp[j], size); + out += size; + inp[j] += size; + if (out == out_end) return; + } + } + }; + Shard(num_threads, worker_threads->workers, output->size(), 100, work); +} + +#define REGISTER(T) \ + template void ConcatCPU( \ + DeviceBase*, \ + const std::vector::ConstMatrix>>&, \ + typename TTypes::Matrix* output); +TF_CALL_ALL_TYPES(REGISTER) +REGISTER(quint8) +REGISTER(qint8) +REGISTER(qint32) +REGISTER(bfloat16) + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc new file mode 100644 index 0000000000..d8ce6bd85d --- /dev/null +++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc @@ -0,0 +1,41 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include + +#include + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +template +void ConcatGPU(const GPUDevice& d, + const std::vector< + std::unique_ptr::ConstMatrix>>& inputs, + typename TTypes::Matrix* output) { + Eigen::array offset(0, 0); + for (int i = 0; i < inputs.size(); ++i) { + Eigen::array size = inputs[i]->dimensions(); + output->slice(offset, size).device(d) = *inputs[i]; + offset[1] += size[1]; + } +} + +#define REGISTER_GPU(T) \ + template void ConcatGPU( \ + const GPUDevice& d, \ + const std::vector::ConstMatrix>>& \ + inputs, \ + typename TTypes::Matrix* output); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +#undef REGISTER_GPU + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/concat_op_test.cc b/tensorflow/core/kernels/concat_op_test.cc new file 
mode 100644 index 0000000000..4ccc5b5b19 --- /dev/null +++ b/tensorflow/core/kernels/concat_op_test.cc @@ -0,0 +1,240 @@ +#include +#include +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +// For the benchmark, we set up two 2-dimensional tensors, each kDim1 x 'dim' +// in size, and concat them together along "concat_dimension" +template +static void ConcatHelper(int iters, int concat_dimension, int dim2) { + testing::StopTiming(); + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + + DataType dt = DataTypeToEnum::v(); + const int kDim1 = 100; + Tensor concat_dim(DT_INT32, TensorShape({})); + concat_dim.scalar()() = concat_dimension; + Tensor in0(dt, TensorShape({kDim1, dim2})); + in0.flat().setRandom(); + Tensor in1(dt, TensorShape({kDim1, dim2})); + in1.flat().setRandom(); + + Node* node; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "Concat") + .Input(test::graph::Constant(g, concat_dim)) + .Input({test::graph::Constant(g, in0), test::graph::Constant(g, in1)}) + .Attr("N", 2) + .Attr("T", dt) + .Finalize(g, &node)); + + testing::BytesProcessed(static_cast(iters) * + ((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(T)); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); + testing::UseRealTime(); +} + +static void BM_ConcatDim0Float(int iters, int dim2) { + ConcatHelper(iters, 0, dim2); +} + +static void BM_ConcatDim1Float(int iters, int dim2) { + ConcatHelper(iters, 1, dim2); +} + +BENCHMARK(BM_ConcatDim0Float)->Arg(1000)->Arg(100000)->Arg(1000000); +BENCHMARK(BM_ConcatDim1Float)->Arg(1000)->Arg(100000)->Arg(1000000); + +static void BM_ConcatDim1int16(int iters, int dim2) { + ConcatHelper(iters, 1, dim2); +} +static void BM_ConcatDim1bfloat16(int iters, int dim2) { + ConcatHelper(iters, 1, dim2); +} + +BENCHMARK(BM_ConcatDim1int16)->Arg(1000)->Arg(100000)->Arg(1000000); +BENCHMARK(BM_ConcatDim1bfloat16)->Arg(1000)->Arg(100000)->Arg(1000000); + +template +static void ConcatManyHelper(int iters, int concat_dimension, int dim2) { + testing::StopTiming(); + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + + DataType dt = DataTypeToEnum::v(); + const int kDim1 = 40000; + const int kNumInputs = 64; + Tensor concat_dim(DT_INT32, TensorShape({})); + concat_dim.scalar()() = concat_dimension; + std::vector inputs; + inputs.reserve(kNumInputs); + for (int i = 0; i < kNumInputs; ++i) { + Tensor in(dt, TensorShape({kDim1, dim2})); + in.flat().setRandom(); + inputs.push_back(test::graph::Constant(g, in)); + } + + Node* node; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat") + .Input(test::graph::Constant(g, concat_dim)) + .Input(inputs) + .Attr("N", 64) + .Attr("T", dt) + .Finalize(g, &node)); + testing::BytesProcessed(static_cast(iters) * kDim1 * dim2 * + kNumInputs * sizeof(T)); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); + testing::UseRealTime(); +} + +static void BM_ConcatManyDim1bfloat16(int iters, int dim2) { 
+ ConcatManyHelper(iters, 1, dim2); +} + +BENCHMARK(BM_ConcatManyDim1bfloat16)->Arg(18)->Arg(34)->Arg(60); + +static void MemcpyAlternativeHelper(int iters, int concat_dimension, int dim2) { + testing::StopTiming(); + + const int kDim1 = 100; + std::vector data1(kDim1 * dim2, 1.0f); + std::vector data2(kDim1 * dim2, 2.0f); + + testing::BytesProcessed(static_cast(iters) * + ((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(float)); + testing::StartTiming(); + while (--iters > 0) { + const int n0 = data1.size(); + const int n1 = data2.size(); + float* result = new float[n0 + n1]; + memcpy(&result[0], &data1[0], n0 * sizeof(float)); + memcpy(&result[n0], &data2[0], n1 * sizeof(float)); + delete[] result; + } +} + +static void BM_MemcpyAlternativeDim0(int iters, int dim2) { + MemcpyAlternativeHelper(iters, 0, dim2); +} +static void BM_MemcpyAlternativeDim1(int iters, int dim2) { + MemcpyAlternativeHelper(iters, 1, dim2); +} + +BENCHMARK(BM_MemcpyAlternativeDim0)->Arg(1000)->Arg(100000)->Arg(1000000); +BENCHMARK(BM_MemcpyAlternativeDim1)->Arg(1000)->Arg(100000)->Arg(1000000); + +typedef Eigen::TensorMap, + Eigen::Unaligned> EigenMap; +static void MemcpyManyAlternative1(int iters, int dim2) { + testing::StopTiming(); + + const int kDim1 = 40000; + const int kNumCopies = 64; + const int size = kDim1 * dim2 * kNumCopies; + bfloat16* data = new bfloat16[size]; + EigenMap map(data, size); + map.setRandom(); + + testing::BytesProcessed(static_cast(iters) * kDim1 * dim2 * + kNumCopies * sizeof(bfloat16)); + testing::StartTiming(); + while (iters-- > 0) { + std::vector inputs(kNumCopies); + for (int i = 0; i < kNumCopies; ++i) { + inputs[i] = &data[i * kDim1 * dim2]; + } + bfloat16* result = new bfloat16[size]; + for (int j = 0; j < kNumCopies; ++j) { + bfloat16* output = &result[j * dim2]; + for (int i = 0; i < kDim1; ++i) { + if (i + 1 < kDim1) { + port::prefetch(inputs[j] + dim2); + } + memcpy(output, inputs[j], dim2 * sizeof(bfloat16)); + inputs[j] += dim2; + output += dim2 * kNumCopies; + } + } + delete[] result; + } + delete[] data; +} + +static void MemcpyManyAlternative2(int iters, int dim2) { + testing::StopTiming(); + + const int kDim1 = 40000; + const int kNumCopies = 64; + const int size = kDim1 * dim2 * kNumCopies; + bfloat16* data = new bfloat16[size]; + EigenMap map(data, size); + map.setRandom(); + + testing::BytesProcessed(static_cast(iters) * kDim1 * dim2 * + kNumCopies * sizeof(bfloat16)); + testing::StartTiming(); + std::vector inputs(kNumCopies); + while (--iters > 0) { + bfloat16* result = new bfloat16[size]; + for (int i = 0; i < kNumCopies; ++i) { + inputs[i] = &data[i * kDim1 * dim2]; + } + bfloat16* output = result; + for (int i = 0; i < kDim1; ++i) { + for (int j = 0; j < kNumCopies; ++j) { + if (j + 1 < kNumCopies) { + port::prefetch(inputs[j + 1]); + } + memcpy(output, inputs[j], dim2 * sizeof(bfloat16)); + inputs[j] += dim2; + output += dim2; + } + } + delete[] result; + } + delete[] data; +} + +BENCHMARK(MemcpyManyAlternative1) + ->Arg(16) + ->Arg(17) + ->Arg(18) + ->Arg(32) + ->Arg(33) + ->Arg(34) + ->Arg(60) + ->Arg(64) + ->Arg(65); + +BENCHMARK(MemcpyManyAlternative2) + ->Arg(16) + ->Arg(17) + ->Arg(18) + ->Arg(32) + ->Arg(33) + ->Arg(34) + ->Arg(60) + ->Arg(64) + ->Arg(65); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc new file mode 100644 index 0000000000..281bafd3df --- /dev/null +++ b/tensorflow/core/kernels/constant_op.cc @@ -0,0 +1,249 @@ +// See docs in ../ops/array_ops.cc. 
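// Illustrative sketch, not part of the original file: ConstantOp decodes its
// output tensor once, at construction time, from the node's "value" attribute
// and returns that cached tensor on every Compute call.  The node it executes
// could be built roughly like this (the name "c" and the scalar value are made
// up, and an Attr overload taking a Tensor is assumed):
//
//   Tensor value(DT_FLOAT, TensorShape({}));
//   value.scalar<float>()() = 1.0f;
//   NodeDef def;
//   TF_CHECK_OK(NodeDefBuilder("c", "Const")
//                   .Attr("dtype", DT_FLOAT)
//                   .Attr("value", value)
//                   .Finalize(&def));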
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/constant_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +ConstantOp::ConstantOp(OpKernelConstruction* ctx) + : OpKernel(ctx), tensor_(ctx->output_type(0)) { + const TensorProto* proto = nullptr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); + OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto( + *proto, AllocatorAttributes(), &tensor_)); + OP_REQUIRES( + ctx, ctx->output_type(0) == tensor_.dtype(), + errors::InvalidArgument("Type mismatch between value (", + DataTypeString(tensor_.dtype()), ") and dtype (", + DataTypeString(ctx->output_type(0)), ")")); +} + +void ConstantOp::Compute(OpKernelContext* ctx) { ctx->set_output(0, tensor_); } + +ConstantOp::~ConstantOp() {} + +REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp); + +#if GOOGLE_CUDA +#define REGISTER_KERNEL(D, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Const").Device(DEVICE_##D).TypeConstraint("dtype"), \ + ConstantOp); +REGISTER_KERNEL(GPU, float); +REGISTER_KERNEL(GPU, double); +REGISTER_KERNEL(GPU, uint8); +REGISTER_KERNEL(GPU, int8); +REGISTER_KERNEL(GPU, int16); +REGISTER_KERNEL(GPU, int64); +REGISTER_KERNEL(GPU, complex64); +REGISTER_KERNEL(GPU, bool); +// Currently we do not support string constants on GPU +#undef REGISTER_KERNEL +#endif + +// HostConstantOp differs from ConstantOp in that its output is always +// in host memory. +class HostConstantOp : public OpKernel { + public: + explicit HostConstantOp(OpKernelConstruction* ctx) + : OpKernel(ctx), tensor_(ctx->output_type(0)) { + const TensorProto* proto = nullptr; + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); + OP_REQUIRES_OK( + ctx, ctx->device()->MakeTensorFromProto(*proto, alloc_attr, &tensor_)); + OP_REQUIRES( + ctx, ctx->output_type(0) == tensor_.dtype(), + errors::InvalidArgument( + "Type mismatch between value (", DataTypeString(tensor_.dtype()), + ") and dtype (", DataTypeString(ctx->output_type(0)), ")")); + } + + void Compute(OpKernelContext* ctx) override { ctx->set_output(0, tensor_); } + + bool IsExpensive() override { return false; } + + ~HostConstantOp() override {} + + private: + Tensor tensor_; + TF_DISALLOW_COPY_AND_ASSIGN(HostConstantOp); +}; + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Const") + .Device(DEVICE_GPU) + .HostMemory("output") + .TypeConstraint("dtype"), + HostConstantOp); + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +// Partial specialization of FillFunctor. +template +struct FillFunctor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstScalar in) { + out.device(d) = out.constant(in()); + } +}; + +// Partial specialization of SetZeroFunctor. 
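// Both functors exist so the ops below can stay device-agnostic: a kernel
// templated on Device simply calls, e.g.,
//
//   functor::FillFunctor<Device, T> fill;
//   fill(context->eigen_device<Device>(), out->flat<T>(), value.scalar<T>());
//
// and the Device argument selects either the CPU specialization above
// (evaluated on the Eigen thread pool) or the GPU specialization in
// constant_op_gpu.cu.cc (evaluated on the CUDA stream).  SetZeroFunctor, which
// follows, provides the same pattern for zero-initializing a tensor.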
+template +struct SetZeroFunctor { + void operator()(const CPUDevice& d, typename TTypes::Flat out) { + out.device(d) = out.constant(0); + } +}; + +#define DEFINE_SETZERO_CPU(T) template struct SetZeroFunctor +DEFINE_SETZERO_CPU(float); +DEFINE_SETZERO_CPU(double); +DEFINE_SETZERO_CPU(int32); +DEFINE_SETZERO_CPU(complex64); +#undef DEFINE_SETZERO_CPU + +} // end namespace functor + +template +class FillOp : public OpKernel { + public: + explicit FillOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& Tdims = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(Tdims.shape()), + errors::InvalidArgument("dims must be a vector of int32.")); + const Tensor& Tvalue = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(Tvalue.shape()), + errors::InvalidArgument("value must be a scalar.")); + auto dims = Tdims.flat(); + for (int i = 0; i < dims.size(); i++) { + OP_REQUIRES(context, dims(i) >= 0, + errors::InvalidArgument("dims[", i, "] = ", dims(i), + " must be nonnegative.")); + } + Tensor* out = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output( + 0, TensorShapeUtils::MakeShape( + reinterpret_cast(dims.data()), dims.size()), + &out)); + functor::FillFunctor functor; + functor(context->eigen_device(), out->flat(), + Tvalue.scalar()); + } +}; + +#define REGISTER_KERNEL(D, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("Fill") \ + .Device(DEVICE_##D) \ + .TypeConstraint("T") \ + .HostMemory("dims"), \ + FillOp); + +#define REGISTER_CPU_KERNEL(TYPE) REGISTER_KERNEL(CPU, TYPE) +TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL); +#undef REGISTER_CPU_KERNEL + +#if GOOGLE_CUDA +REGISTER_KERNEL(GPU, float); +REGISTER_KERNEL(GPU, double); +REGISTER_KERNEL(GPU, uint8); +REGISTER_KERNEL(GPU, int8); +REGISTER_KERNEL(GPU, int16); +REGISTER_KERNEL(GPU, int64); +// Currently we do not support filling strings and complex64 on GPU + +#endif // GOOGLE_CUDA + +#undef REGISTER_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
+REGISTER_KERNEL_BUILDER(Name("Fill") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("dims") + .HostMemory("value") + .HostMemory("output"), + FillOp); + +template +class ZerosLikeOp : public OpKernel { + public: + explicit ZerosLikeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& input = ctx->input(0); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out)); + Tensor zero(DataTypeToEnum::value, {1}); + zero.scalar().setZero(); + const Tensor& zero_cref = zero; + functor::FillFunctor functor; + functor(ctx->eigen_device(), out->flat(), zero_cref.scalar()); + } +}; + +#define REGISTER_KERNEL(type, dev) \ + REGISTER_KERNEL_BUILDER( \ + Name("ZerosLike").Device(DEVICE_##dev).TypeConstraint("T"), \ + ZerosLikeOp) + +#define REGISTER_CPU(type) REGISTER_KERNEL(type, CPU) +TF_CALL_ALL_TYPES(REGISTER_CPU); +#undef REGISTER_CPU + +#if GOOGLE_CUDA +REGISTER_KERNEL(float, GPU); +REGISTER_KERNEL(double, GPU); +#endif // GOOGLE_CUDA + +#undef REGISTER_KERNEL + +class PlaceholderOp : public OpKernel { + public: + explicit PlaceholderOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_)); + } + + void Compute(OpKernelContext* ctx) override { + if (expected_shape_.dims() > 0) { + OP_REQUIRES(ctx, false, + errors::InvalidArgument( + "You must feed a value for placeholder tensor '", name(), + "' with dtype ", DataTypeString(output_type(0)), + " and shape ", expected_shape_.DebugString())); + } else { + OP_REQUIRES(ctx, false, + errors::InvalidArgument( + "You must feed a value for placeholder tensor '", name(), + "' with dtype ", DataTypeString(output_type(0)))); + } + } + + private: + TensorShape expected_shape_; +}; + +REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_CPU), PlaceholderOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/constant_op.h b/tensorflow/core/kernels/constant_op.h new file mode 100644 index 0000000000..20a5c9c42f --- /dev/null +++ b/tensorflow/core/kernels/constant_op.h @@ -0,0 +1,25 @@ +#ifndef TENSORFLOW_KERNELS_CONSTANT_OP_H_ +#define TENSORFLOW_KERNELS_CONSTANT_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +// ConstantOp returns a tensor specified by ConstantOpDef. 
+class ConstantOp : public OpKernel { + public: + explicit ConstantOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + ~ConstantOp() override; + + private: + Tensor tensor_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CONSTANT_OP_H_ diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc new file mode 100644 index 0000000000..64502378bd --- /dev/null +++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc @@ -0,0 +1,89 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/platform/port.h" + +namespace Eigen { +namespace internal { + +template +struct scalar_const_op { + typedef typename packet_traits::type Packet; + + const T* val; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + scalar_const_op(const scalar_const_op& x) + : val(x.val) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_const_op(const T* v) : val(v) {} + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(Index, + Index = 0) const { + return *val; + } + + template + EIGEN_STRONG_INLINE const Packet packetOp(Index, Index = 0) const { + return internal::pset1(*val); + } +}; + +template +struct functor_traits > { + enum { + Cost = 1, + PacketAccess = packet_traits::Vectorizable, + IsRepeatable = true + }; +}; + +} // end namespace internal +} // end namespace Eigen + +namespace tensorflow { + +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +// Partial specialization FillFunctor +template +struct FillFunctor { + void operator()(const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstScalar in) { + Eigen::internal::scalar_const_op f(in.data()); + out.device(d) = out.nullaryExpr(f); + } +}; + +#define DEFINE_FILL_GPU(T) template struct FillFunctor +DEFINE_FILL_GPU(float); +DEFINE_FILL_GPU(double); +DEFINE_FILL_GPU(int32); +DEFINE_FILL_GPU(uint8); +DEFINE_FILL_GPU(int16); +DEFINE_FILL_GPU(int8); +DEFINE_FILL_GPU(int64); +#undef DEFINE_FILL_GPU + +// Partial specialization of FillFunctor. +template +struct SetZeroFunctor { + void operator()(const GPUDevice& d, typename TTypes::Flat out) { + out.device(d) = out.constant(0); + } +}; + +#define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor +DEFINE_SETZERO_GPU(float); +#undef DEFINE_SETZERO_GPU + +} // end namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/constant_op_test.cc b/tensorflow/core/kernels/constant_op_test.cc new file mode 100644 index 0000000000..f5a464c07c --- /dev/null +++ b/tensorflow/core/kernels/constant_op_test.cc @@ -0,0 +1,43 @@ +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +// Returns graph containing "num" const nodes. If 'sequential' is +// true, make sure all constants are executed sequentially in the +// graph by adding control dependencies. 
+static Graph* ManyConsts(int num, bool sequential) { + Graph* g = new Graph(OpRegistry::Global()); + Node* prev = nullptr; + for (int i = 0; i < num; ++i) { + Tensor c(DT_FLOAT, TensorShape({})); + c.scalar()() = i; + Node* curr = test::graph::Constant(g, c); + if (sequential && prev != nullptr) { + g->AddControlEdge(prev, curr); + } + prev = curr; + } + return g; +} + +static void BM_ManyConsts_Parallel(int iters, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + test::Benchmark("cpu", ManyConsts(num, false /* !sequential */)).Run(iters); +} +BENCHMARK(BM_ManyConsts_Parallel)->Range(1, 1 << 10); + +static void BM_ManyConsts_Sequential(int iters, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + test::Benchmark("cpu", ManyConsts(num, true /* sequential */)).Run(iters); +} +BENCHMARK(BM_ManyConsts_Sequential)->Range(1, 1 << 10); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc new file mode 100644 index 0000000000..bc44a7f7cc --- /dev/null +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -0,0 +1,359 @@ +#include "tensorflow/core/kernels/control_flow_ops.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +// A switch op has two inputs and two outputs. It forwards the value of +// Input:0 to the output specified by input:1. Input:1 is a boolean tensor. +// Input:0 is forwarded to output:0 if input:1 is false, otherwise to +// output:1. +class SwitchOp : public OpKernel { + public: + explicit SwitchOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& outputPorts = context->input(1); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(outputPorts.shape()), + errors::InvalidArgument("The second input must be a scalar, " + "but it has shape ", + outputPorts.shape().ShortDebugString())); + + bool pred = outputPorts.scalar()(); + int port = (pred) ? 
1 : 0; + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, port); + } else { + context->set_output(port, context->input(0)); + } + } + + bool IsExpensive() override { return false; } + + ~SwitchOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp); +}; + +#define REGISTER_CPU_SWITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("Switch") \ + .Device(DEVICE_CPU) \ + .HostMemory("pred") \ + .TypeConstraint("T"), \ + SwitchOp) + +#define REGISTER_CPU_REF_SWITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ + .Device(DEVICE_CPU) \ + .HostMemory("pred") \ + .TypeConstraint("T"), \ + SwitchOp) + +#define REGISTER_GPU_SWITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("Switch") \ + .Device(DEVICE_GPU) \ + .HostMemory("pred") \ + .TypeConstraint("T"), \ + SwitchOp) + +#define REGISTER_GPU_REF_SWITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ + .Device(DEVICE_GPU) \ + .HostMemory("pred") \ + .TypeConstraint("T"), \ + SwitchOp) + +TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH); +TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SWITCH); +REGISTER_GPU_SWITCH(bool); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_REF_SWITCH); +REGISTER_GPU_REF_SWITCH(int32); +REGISTER_GPU_REF_SWITCH(bool); + +#undef REGISTER_CPU_SWITCH +#undef REGISTER_CPU_REF_SWITCH +#undef REGISTER_GPU_SWITCH +#undef REGISTER_GPU_REF_SWITCH + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Switch") + .Device(DEVICE_GPU) + .HostMemory("data") + .HostMemory("pred") + .HostMemory("output_false") + .HostMemory("output_true") + .TypeConstraint("T"), + SwitchOp); + +class RefSelectOp : public OpKernel { + public: + explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& index_tensor = context->input(0); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(index_tensor.shape()), + errors::InvalidArgument("Index must be a scalar, " + "but it has shape ", + index_tensor.shape().ShortDebugString())); + + int32 index = index_tensor.scalar()(); + + OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_, + errors::InvalidArgument("Index must be in the range [0, ", + num_ref_inputs_, ") but got ", index)); + context->forward_ref_input_to_ref_output(index + 1, 0); + } + + bool IsExpensive() override { return false; } + + ~RefSelectOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp); + + private: + int num_ref_inputs_; +}; + +#define REGISTER_CPU_REF_SELECT(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSelect") \ + .Device(DEVICE_CPU) \ + .HostMemory("index") \ + .TypeConstraint("T"), \ + RefSelectOp) +TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT); + +#undef REGISTER_CPU_REF_SWITCH + +// A merge op has n inputs and two outputs. It forwards the value of the +// first input that becomes available to its first output, and the +// index of the first input to its second output. 
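// Illustrative sketch, not part of the original file: how Switch and Merge are
// typically wired together when a conditional is lowered into the graph.  The
// node names here are made up; the N and T attrs are inferred from the inputs
// by NodeBuilder.
//
//   Node* sw;
//   TF_CHECK_OK(NodeBuilder(g->NewName("switch"), "Switch")
//                   .Input(data)    // value to forward
//                   .Input(pred)    // boolean scalar
//                   .Finalize(g, &sw));
//   Node* merge;
//   TF_CHECK_OK(NodeBuilder(g->NewName("merge"), "Merge")
//                   .Input({NodeBuilder::NodeOut(sw, 0),   // false branch
//                           NodeBuilder::NodeOut(sw, 1)})  // true branch
//                   .Finalize(g, &merge));
//
// Only the taken branch produces a live tensor, and Merge forwards it (plus
// its index) as soon as it arrives -- which is why Compute() below reports an
// error if more than one input turns out to be valid.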
+class MergeOp : public OpKernel { + public: + explicit MergeOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = context->input_type(0); + const int num_in = context->num_inputs(); + OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt), + {dt, DT_INT32})); + } + + void Compute(OpKernelContext* context) override { + bool input_seen = false; + for (int i = 0; i < context->num_inputs(); ++i) { + if (context->has_input(i)) { + if (input_seen) { + context->SetStatus(errors::Internal( + "Merge can not have more than one valid input.")); + return; + } + input_seen = true; + + context->set_output(0, context->input(i)); + Tensor* value_index = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), + &value_index)); + value_index->scalar()() = i; + } + } + } + + bool IsExpensive() override { return false; } + + ~MergeOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(MergeOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Merge") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("value_index"), \ + MergeOp); + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); + +#undef REGISTER_GPU_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Merge") + .Device(DEVICE_GPU) + .HostMemory("inputs") + .HostMemory("output") + .HostMemory("value_index") + .TypeConstraint("T"), + MergeOp); + +// An enter op has one input and one output. It creates or finds +// the child frame that is uniquely identified by the frame_name, +// and makes its input available to the child frame. +class EnterOp : public OpKernel { + public: + explicit EnterOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } + } + + bool IsExpensive() override { return false; } + + ~EnterOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(EnterOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp); +REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Enter").Device(DEVICE_GPU).TypeConstraint("T"), EnterOp); +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefEnter").Device(DEVICE_GPU).TypeConstraint("T"), EnterOp); + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +TF_CALL_NUMBER_TYPES(REGISTER_GPU_REF_KERNEL); + +#undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Enter") + .Device(DEVICE_GPU) + .HostMemory("data") + .HostMemory("output") + .TypeConstraint("T"), + EnterOp); + +// An exit op has one input and one output. It exits the current +// frame to its parent frame, and makes its input available to the +// parent frame. 
+class ExitOp : public OpKernel { + public: + explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + context->set_output(0, context->input(0)); + } + + bool IsExpensive() override { return false; } + + ~ExitOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(ExitOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Exit").Device(DEVICE_GPU).TypeConstraint("T"), ExitOp); + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); + +#undef REGISTER_GPU_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Exit") + .Device(DEVICE_GPU) + .HostMemory("data") + .HostMemory("output") + .TypeConstraint("T"), + ExitOp); + +// A next_iteration op has one input and one output. It makes its input +// available to the next iteration. +class NextIterationOp : public OpKernel { + public: + explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + context->set_output(0, context->input(0)); + } + + bool IsExpensive() override { return false; } + + ~NextIterationOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp); +}; + +REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU), + NextIterationOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("NextIteration").Device(DEVICE_GPU).TypeConstraint("T"), \ + NextIterationOp); + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); + +#undef REGISTER_GPU_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("NextIteration") + .Device(DEVICE_GPU) + .HostMemory("data") + .HostMemory("output") + .TypeConstraint("T"), + NextIterationOp); + +// A LoopCond op has one input and one output. The input is a boolean +// scalar representing the taken branches of the "pivot" Switch that +// determines loop termination. As a contract, any high-level front-end +// should always use port '0' of the "pivot" switches for loop exit. 
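// Together, Enter/Merge/Switch/Exit/NextIteration/LoopCond are the pieces a
// while loop is lowered into.  Roughly (an illustrative sketch of the edges,
// not code in this patch), "while (cond(x)) x = body(x)" becomes:
//
//   Enter(x)             --> Merge.input[0]
//   NextIteration        --> Merge.input[1]
//   Merge                --> cond subgraph --> LoopCond --> Switch.pred
//   Merge                --> Switch.data
//   Switch.output_true   --> body subgraph --> NextIteration
//   Switch.output_false  --> Exit
//
// Merge forwards whichever of its inputs becomes available, LoopCond drives
// the pivot Switch, the true output feeds the body and loops back, and the
// false output leaves the frame through Exit -- hence the contract above that
// port 0 of the pivot Switch is the loop exit.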
+class LoopCondOp : public OpKernel { + public: + explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + context->set_output(0, context->input(0)); + } + + bool IsExpensive() override { return false; } + + ~LoopCondOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); +}; + +REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp); +REGISTER_KERNEL_BUILDER(Name("LoopCond") + .Device(DEVICE_GPU) + .HostMemory("input") + .HostMemory("output"), + LoopCondOp); + +// ControlTrigger kernels +REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU), + ControlTriggerOp); + +REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU), + ControlTriggerOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h new file mode 100644 index 0000000000..184cc9fb63 --- /dev/null +++ b/tensorflow/core/kernels/control_flow_ops.h @@ -0,0 +1,22 @@ +#ifndef TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_ +#define TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// A ControlTriggerOp is similar to a NoOp. However, it always treats the input +// control edges as Live edges. Its primary use so far is in the scheduling of +// recvs, where we add ControlTrigger nodes and use them to trigger recvs. We +// allow ControlTrigger nodes to be enabled by dead nodes. +class ControlTriggerOp : public OpKernel { + public: + explicit ControlTriggerOp(OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} + bool IsExpensive() override { return false; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_ diff --git a/tensorflow/core/kernels/control_flow_ops_test.cc b/tensorflow/core/kernels/control_flow_ops_test.cc new file mode 100644 index 0000000000..52bc11abf0 --- /dev/null +++ b/tensorflow/core/kernels/control_flow_ops_test.cc @@ -0,0 +1,71 @@ +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/public/tensor.h" +#include + +namespace tensorflow { +namespace { + +// Tests for the switch op +class SwitchOpTest : public OpsTestBase { + protected: + void Initialize(DataType dt) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("op", "Switch") + .Input(FakeInput(dt)) + .Input(FakeInput()) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(SwitchOpTest, Int32Success_6_s0) { + Initialize(DT_INT32); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({}), {false}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + EXPECT_EQ(nullptr, GetOutput(1)); +} + +TEST_F(SwitchOpTest, Int32Success_6_s1) { + Initialize(DT_INT32); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({}), {true}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + 
EXPECT_EQ(nullptr, GetOutput(0)); +} + +TEST_F(SwitchOpTest, Int32Success_2_3_s0) { + Initialize(DT_INT32); + AddInputFromArray(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({}), {false}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({2, 3})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + EXPECT_EQ(nullptr, GetOutput(1)); +} + +TEST_F(SwitchOpTest, StringSuccess_s1) { + Initialize(DT_STRING); + AddInputFromArray(TensorShape({6}), {"A", "b", "C", "d", "E", "f"}); + AddInputFromArray(TensorShape({}), {true}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({6})); + test::FillValues(&expected, {"A", "b", "C", "d", "E", "f"}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + EXPECT_EQ(nullptr, GetOutput(0)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h new file mode 100644 index 0000000000..2fb623244c --- /dev/null +++ b/tensorflow/core/kernels/conv_2d.h @@ -0,0 +1,127 @@ +#ifndef TENSORFLOW_KERNELS_CONV_2D_H_ +#define TENSORFLOW_KERNELS_CONV_2D_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// TODO(yangke): revisit these operations and in particular, see if we can +// combine all of them into just one operation without causing nvcc to +// timeout. +template +struct ShuffleAndReverse { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& order, + const Eigen::array& reverse_dims, + typename TTypes::Tensor output) { + output.device(d) = input.shuffle(order).reverse(reverse_dims); + } +}; + +template +struct InflatePadAndShuffle { + void operator()( + const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + const Eigen::array, Dims>& pad_dims, + const Eigen::DSizes& order, + typename TTypes::Tensor output) { + output.device(d) = input.inflate(strides).pad(pad_dims).shuffle(order); + } +}; + +template +void SpatialConvolutionFunc(const Device& d, Output output, Input input, + Filter filter, int stride, + const Eigen::PaddingType& padding) { + output.device(d) = Eigen::SpatialConvolution(input, filter, stride, padding); +} + +template +struct SpatialConvolution { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int stride, + const Eigen::PaddingType& padding) { + SpatialConvolutionFunc(d, output, input, filter, stride, padding); + } +}; + +template +struct SpatialConvolutionBackwardInput { + void operator()(const Device& d, typename TTypes::Tensor input_backward, + typename TTypes::ConstTensor kernel, + typename TTypes::ConstTensor output_backward, + int input_rows, int input_cols, int stride) { + input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput( + kernel, output_backward, input_rows, input_cols, stride); + } +}; + +template +struct SpatialConvolutionBackwardKernel { + void operator()(const Device& d, + typename TTypes::Tensor kernel_backward, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor output_backward, + int kernel_rows, int kernel_cols, int stride) { + kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel( + input, output_backward, kernel_rows, 
kernel_cols, stride); + } +}; + +// TODO(vrv): Figure out how to use the MatMulFunctor in matmul_op.h. +// My initial attempt to do this compiled but failed in the pytest +// due to a swigdeps error. +template +struct MatMulConvFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename TTypes::ConstTensor in1, + const Eigen::array, 1>& dim_pair) { + out.device(d) = in0.contract(in1, dim_pair); + } +}; + +template +struct TransformFilter { + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { + out.device(d) = in.shuffle(Eigen::DSizes(3, 2, 0, 1)); + } +}; + +template +struct TransformDepth { + void operator()(const Device& d, typename TTypes::ConstTensor in, + const Eigen::DSizes& shuffle, + typename TTypes::Tensor out) { + out.device(d) = in.shuffle(shuffle); + } +}; + +template +struct PadInput { + void operator()(const Device& d, typename TTypes::ConstTensor in, + int padding_rows_left, int padding_rows_right, + int padding_cols_left, int padding_cols_right, + typename TTypes::Tensor out) { + Eigen::array, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_rows_left, padding_rows_right); + padding[2] = std::make_pair(padding_cols_left, padding_cols_right); + padding[3] = std::make_pair(0, 0); + out.device(d) = in.pad(padding); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CONV_2D_H_ diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc new file mode 100644 index 0000000000..bb21d7003c --- /dev/null +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -0,0 +1,1190 @@ +// See docs in ../ops/nn_ops.cc. + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/util/use_cudnn.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/public/tensor.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/stream_executor/stream.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// The operation to compute Conv2D gradients. +// +// +// To compute the gradients for Conv2D, we need three input tensors: +// input, filter, and backprop for output. +// And we need to compute two backprops: one for input and one for filter. We +// compute them in two different kernels. + +// Both backprops can be computed as straightforward conv2d. 
+// +// Consider a case where the input is 3x3 and the filter is 2x1: +// +// INPUT = [ A B C ] +// [ D E F ] +// [ G H I ] +// +// where each "A", "B", etc is batch x in_depth +// +// FILTER = [ X Y ] +// +// where both "X" and "Y" are in_depth x out_depth +// +// With VALID padding, the output is 3x2: +// +// OUTPUT = [ a b ] +// [ c d ] +// [ e f ] +// +// where each "a", "b", etc is batch x out_depth +// +// So we have: +// +// a = A * X + B * Y +// b = B * X + C * Y +// c = D * X + E * Y +// d = E * X + F * Y +// e = G * X + H * Y +// f = H * X + I * Y +// +// So when we have backprops for the outputs (we denote them by +// a', b', ... ): +// +// The backprops for the input are: +// +// A' = a' * X^t +// B' = a' * Y^t + b' * X^t +// C' = b' * Y^t +// ... +// +// This is essentially computing a 2d conv of +// +// INPUT = [ 0 a' b' 0 ] +// [ 0 c' d' 0 ] +// [ 0 e' f' 0 ] +// and +// +// FILTER = [ Y^t X^t ] +// +// The backprops for the filter are: +// +// X' = A^t * a' + B^t * b' + D^t * c' + E^t * d' + G^t * e' + H^t * f' +// Y' = B^t * a' + C^t * b' + E^t + c' + F^t * d' + H^t * e' + I^t * f' +// +// This is essentially computing a 2d conv of +// +// INPUT = [ A^t B^t C^t ] +// [ D^t E^t F^t ] +// [ G^t H^t I^t ] +// +// and +// +// FILTER = [ a' b' ] +// [ c' d' ] +// [ e' f' ] +// +// +////////////////////////////////////////////////////////// +// +// With stride more than one, it's a bit more complicated (we will need to +// create holes to the backprop). +// +// Consider the case where +// +// INPUT = [ A B C D E ] +// [ F G H I J ] +// [ K L M N O ] +// and +// +// FILTER = [ X Y Z ] +// +// with stride 2. +// +// The output will be +// +// OUTPUT = [ a b ] +// [ c d ] +// +// where: +// +// a = A * X + B * Y + C * Z +// b = C * X + D * Y + E * Z +// c = K * X + L * Y + M * Z +// d = M * X + N * Y + O * Z +// +// +// To compute the backprop for INPUT, we need to convolve +// +// INPUT = [ 0 0 a' 0 b' 0 0 ] +// [ 0 0 0 0 0 0 0 ] +// [ 0 0 c' 0 d' 0 0 ] +// +// (notice the holes in INPUT) +// +// and +// +// FILTER = [ Z^t Y^t X^t ] +// +// with stride 1. +// +// To compute the backprop for FILTER, we need to convolve + +// +// INPUT = [ A^t B^t C^t D^t E^t ] +// [ F^t G^t H^t I^t J^t ] +// [ K^t L^t M^t N^t O^t ] +// and +// +// FILTER = [ a' 0 b' ] +// [ 0 0 0 ] +// [ c' 0 d' ] +// +// (notice the holes in FILTER) +// +// +// with stride 1 +// +////////////////////////////////////////////////////////// +// +// +// The case for SAME padding is in fact very similar to VALID -- we just +// need to pad the input tensor a bit when computing the filter_backprop. + +// Common code between the two kernels: verifies that the dimensions all match +// and extract the padded rows and columns. 
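// Standalone numerical sketch (not part of this file) of the identity derived
// in the comment above, in the simplest 1-D setting: a length-3 input, a
// length-2 filter [X, Y], VALID padding, stride 1.  The gradient w.r.t. the
// input equals a convolution of the zero-padded output gradient with the
// reversed filter [Y, X].
#include <array>
#include <cassert>

static void CheckConv1dInputGradIdentity() {
  const std::array<float, 2> w = {4.f, 5.f};    // filter [X, Y]
  const std::array<float, 2> g = {0.5f, -1.f};  // upstream grads [a', b']
  // Forward pass: out0 = in0*X + in1*Y, out1 = in1*X + in2*Y, so by direct
  // differentiation the input gradients are:
  const std::array<float, 3> want = {g[0] * w[0],                // in0' = a'X
                                     g[0] * w[1] + g[1] * w[0],  // in1' = a'Y + b'X
                                     g[1] * w[1]};               // in2' = b'Y
  // The same numbers fall out of conv([0, a', b', 0], [Y, X]) with stride 1:
  const std::array<float, 4> padded = {0.f, g[0], g[1], 0.f};
  for (int i = 0; i < 3; ++i) {
    const float got = padded[i] * w[1] + padded[i + 1] * w[0];
    assert(got == want[i]);
  }
}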
+#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \ + const Tensor& out_backprop = context->input(2); \ + OP_REQUIRES( \ + context, input_shape.dims() == 4, \ + errors::InvalidArgument(label, ": input must be 4-dimensional")); \ + OP_REQUIRES( \ + context, filter_shape.dims() == 4, \ + errors::InvalidArgument(label, ": filter must be 4-dimensional")); \ + OP_REQUIRES( \ + context, out_backprop.dims() == 4, \ + errors::InvalidArgument(label, ": out_backprop must be 4-dimensional")); \ + const int64 batch = input_shape.dim_size(0); \ + OP_REQUIRES( \ + context, batch == out_backprop.dim_size(0), \ + errors::InvalidArgument( \ + label, ": input and out_backprop must have the same batch size")); \ + const int64 input_rows = input_shape.dim_size(1); \ + const int64 input_cols = input_shape.dim_size(2); \ + const int64 filter_rows = filter_shape.dim_size(0); \ + const int64 filter_cols = filter_shape.dim_size(1); \ + const int64 output_rows = out_backprop.dim_size(1); \ + const int64 output_cols = out_backprop.dim_size(2); \ + const int64 in_depth = input_shape.dim_size(3); \ + OP_REQUIRES(context, in_depth == filter_shape.dim_size(2), \ + errors::InvalidArgument( \ + label, ": input and filter must have the same depth")); \ + const int64 out_depth = filter_shape.dim_size(3); \ + OP_REQUIRES( \ + context, out_depth == out_backprop.dim_size(3), \ + errors::InvalidArgument( \ + label, ": filter and out_backprop must have the same out_depth")); \ + const auto stride = strides_[1]; \ + int out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; \ + if (filter_cols == filter_rows && filter_rows == 1 && stride == 1) { \ + out_rows = input_rows; \ + out_cols = input_cols; \ + } else { \ + OP_REQUIRES_OK( \ + context, Get2dOutputSize(input_rows, input_cols, filter_rows, \ + filter_cols, stride, stride, padding_, \ + &out_rows, &out_cols, &pad_rows, &pad_cols)); \ + } \ + OP_REQUIRES( \ + context, output_rows == out_rows, \ + errors::InvalidArgument( \ + label, ": Number of rows of out_backprop doesn't match computed: ", \ + "actual = ", output_rows, ", computed = ", out_rows)); \ + OP_REQUIRES( \ + context, output_cols == out_cols, \ + errors::InvalidArgument( \ + label, ": Number of cols of out_backprop doesn't match computed: ", \ + "actual = ", output_cols, ", computed = ", out_cols)); \ + const auto expanded_out_rows = (output_rows - 1) * stride + 1; \ + const auto expanded_out_cols = (output_cols - 1) * stride + 1; \ + const auto padded_out_rows = input_rows + filter_rows - 1; \ + const auto padded_out_cols = input_cols + filter_cols - 1; \ + const auto top_pad_rows = filter_rows - 1 - pad_rows; \ + const auto left_pad_cols = filter_cols - 1 - pad_cols; \ + const auto bottom_pad_rows = \ + padded_out_rows - expanded_out_rows - top_pad_rows; \ + const auto right_pad_cols = \ + padded_out_cols - expanded_out_cols - left_pad_cols; \ + Eigen::DSizes strides{1, stride, stride, 1}; \ + VLOG(2) << "Conv2d: " << label \ + << ": expanded_out_rows = " << expanded_out_rows \ + << ", expanded_out_cols = " << expanded_out_cols \ + << ", filter_rows = " << filter_rows \ + << ", filter_cols = " << filter_cols \ + << ", padded_out_rows = " << padded_out_rows \ + << ", padded_out_cols = " << padded_out_cols \ + << ", top_pad_rows = " << top_pad_rows \ + << ", left_pad_cols = " << left_pad_cols \ + << ", bottom_pad_rows = " << bottom_pad_rows \ + << ", right_pad_cols = " << right_pad_cols \ + << ", strides = " << strides[1] + +namespace { +TensorShape VectorToShape(const TTypes::ConstVec& sizes) { + TensorShape shape; + 
+ using Index = TTypes::ConstVec::Index; + const Index dims = sizes.size(); + for (Index i = 0; i < dims; ++i) { + shape.AddDim(sizes(i)); + } + + return shape; +} +} // namespace + +// The fast versions using eigen computations directly. They are only enabled +// for CPU for now since nvcc times out when trying to compile them. +// TODO(yangke): enable them for GPUs when we have a faster compiler. + +template +class Conv2DFastBackpropInputOp : public OpKernel { + public: + explicit Conv2DFastBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument( + "Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_sizes = context->input(0); + const Tensor& filter = context->input(1); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(input_sizes.shape()), + errors::InvalidArgument( + "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", + input_sizes.dims())); + TensorShape input_shape = VectorToShape(input_sizes.vec()); + const TensorShape& filter_shape = filter.shape(); + + EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput"); + Tensor* in_backprop = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &in_backprop)); + // Need to flip the input_rows and input_cols when passing to eigen. + functor::SpatialConvolutionBackwardInput()( + context->eigen_device(), in_backprop->tensor(), + filter.tensor(), out_backprop.tensor(), input_cols, + input_rows, stride); + } + + private: + std::vector strides_; + Padding padding_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DFastBackpropInputOp); +}; + +// Based on implementation written by Yangqing Jia (jiayq). 
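// Rough standalone sketch (not the Im2col helper the kernels below actually
// call) of the column-buffer layout both "Custom" backprop ops rely on: for
// every output position, one row holding the filter_rows * filter_cols *
// in_depth input values the filter sees there.  NHWC layout, VALID padding,
// stride 1, a single image; the real helper additionally handles padding and
// strides.
static void Im2colSketch(const float* in, int in_rows, int in_cols, int depth,
                         int filter_rows, int filter_cols, float* col) {
  const int out_rows = in_rows - filter_rows + 1;
  const int out_cols = in_cols - filter_cols + 1;
  // col holds (out_rows * out_cols) rows of (filter_rows * filter_cols * depth).
  float* dst = col;
  for (int r = 0; r < out_rows; ++r) {
    for (int c = 0; c < out_cols; ++c) {
      for (int fr = 0; fr < filter_rows; ++fr) {
        for (int fc = 0; fc < filter_cols; ++fc) {
          const float* src = in + ((r + fr) * in_cols + (c + fc)) * depth;
          for (int d = 0; d < depth; ++d) {
            *dst++ = src[d];
          }
        }
      }
    }
  }
}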
+template +class Conv2DCustomBackpropInputOp : public OpKernel { + public: + explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES( + context, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_sizes = context->input(0); + const Tensor& filter = context->input(1); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(input_sizes.shape()), + errors::InvalidArgument( + "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", + input_sizes.dims())); + TensorShape input_shape = VectorToShape(input_sizes.vec()); + const TensorShape& filter_shape = filter.shape(); + + EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput"); + Tensor* in_backprop = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &in_backprop)); + + // TODO(andydavis) Consider moving code shared with + // Conv2DCustomBackpropFilterOp into a shared helper function. + int pad_top; + int pad_bottom; + int pad_left; + int pad_right; + OP_REQUIRES_OK( + context, + Get2dOutputSizeVerbose(input_rows, input_cols, filter_rows, filter_cols, + stride, stride, padding_, &out_rows, &out_cols, + &pad_top, &pad_bottom, &pad_left, &pad_right)); + + // The total dimension size of each kernel. + const int filter_total_size = filter_rows * filter_cols * in_depth; + // The output image size is the spatial size of the output. + const int output_image_size = out_rows * out_cols; + + Tensor col_buffer; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({output_image_size, filter_total_size}), &col_buffer)); + + // The input offset corresponding to a single input image. + const int input_offset = input_rows * input_cols * in_depth; + // The output offset corresponding to a single output image. + const int output_offset = out_rows * out_cols * out_depth; + + auto* filter_data = filter.template flat().data(); + auto* col_buffer_data = col_buffer.template flat().data(); + auto* out_backprop_data = out_backprop.template flat().data(); + auto* input_backprop_data = in_backprop->template flat().data(); + + typedef Eigen::Map> MatrixMap; + typedef Eigen::Map> ConstMatrixMap; + + for (int image_id = 0; image_id < batch; ++image_id) { + // Compute gradient into col_buffer. + MatrixMap C(col_buffer_data, output_image_size, filter_total_size); + + ConstMatrixMap A(out_backprop_data + output_offset * image_id, + output_image_size, out_depth); + ConstMatrixMap B(filter_data, filter_total_size, out_depth); + + // TODO(andydavis) Use a multi-threaded matmul implementation here. 
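      // Shapes in this GEMM, one image at a time:
      //   A: [out_rows*out_cols, out_depth]                 (output gradient)
      //   B: [filter_rows*filter_cols*in_depth, out_depth]  (filter)
      //   C: [out_rows*out_cols, filter_rows*filter_cols*in_depth]
      // so C = A * B^T gives, for every output position, the gradient flowing
      // through each filter tap; Col2im below then scatters those values back
      // into the dense input gradient, accumulating wherever receptive fields
      // overlap.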
+ C.noalias() = A * B.transpose(); + + Col2im(col_buffer_data, in_depth, input_rows, input_cols, filter_rows, + filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride, + stride, input_backprop_data); + + input_backprop_data += input_offset; + } + } + + private: + std::vector strides_; + Padding padding_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + Conv2DCustomBackpropInputOp); + +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") + .Device(DEVICE_CPU) + .Label("custom") + .TypeConstraint("T"), + Conv2DCustomBackpropInputOp); + +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") + .Device(DEVICE_CPU) + .Label("eigen_tensor") + .TypeConstraint("T"), + Conv2DFastBackpropInputOp); + +template +class Conv2DFastBackpropFilterOp : public OpKernel { + public: + explicit Conv2DFastBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument( + "Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& filter_sizes = context->input(1); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(filter_sizes.shape()), + errors::InvalidArgument( + "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ", + filter_sizes.dims())); + const TensorShape& input_shape = input.shape(); + TensorShape filter_shape = VectorToShape(filter_sizes.vec()); + + EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropFilter"); + Tensor* filter_backprop = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, filter_shape, &filter_backprop)); + + // Need to flip the filter_rows and filter_cols when passing to eigen. + functor::SpatialConvolutionBackwardKernel()( + context->eigen_device(), filter_backprop->tensor(), + input.tensor(), out_backprop.tensor(), filter_cols, + filter_rows, stride); + } + + private: + std::vector strides_; + Padding padding_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DFastBackpropFilterOp); +}; + +// Based on implementation written by Yangqing Jia (jiayq). 
+template +class Conv2DCustomBackpropFilterOp : public OpKernel { + public: + explicit Conv2DCustomBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES( + context, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& filter_sizes = context->input(1); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(filter_sizes.shape()), + errors::InvalidArgument( + "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, " + "not ", + filter_sizes.dims())); + const TensorShape& input_shape = input.shape(); + TensorShape filter_shape = VectorToShape(filter_sizes.vec()); + + EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DCustomBackpropFilter"); + Tensor* filter_backprop; + OP_REQUIRES_OK(context, + context->allocate_output(0, filter_shape, &filter_backprop)); + + int pad_top; + int pad_bottom; + int pad_left; + int pad_right; + OP_REQUIRES_OK( + context, + Get2dOutputSizeVerbose(input_rows, input_cols, filter_rows, filter_cols, + stride, stride, padding_, &out_rows, &out_cols, + &pad_top, &pad_bottom, &pad_left, &pad_right)); + + // The total dimension size of each kernel. + const int filter_total_size = filter_rows * filter_cols * in_depth; + // The output image size is the spatial size of the output. + const int output_image_size = out_rows * out_cols; + + Tensor col_buffer; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({output_image_size, filter_total_size}), &col_buffer)); + + // The input offset corresponding to a single input image. + const int input_offset = input_rows * input_cols * in_depth; + // The output offset corresponding to a single output image. + const int output_offset = out_rows * out_cols * out_depth; + + auto* input_data = input.template flat().data(); + auto* col_buffer_data = col_buffer.template flat().data(); + auto* out_backprop_data = out_backprop.template flat().data(); + auto* filter_backprop_data = filter_backprop->template flat().data(); + + typedef Eigen::Map> MatrixMap; + typedef Eigen::Map> ConstMatrixMap; + + MatrixMap C(filter_backprop_data, filter_total_size, out_depth); + + C.setZero(); + for (int image_id = 0; image_id < batch; ++image_id) { + // When we compute the gradient with respect to the filters, we need to do + // im2col to allow gemm-type computation. + Im2col(input_data, in_depth, input_rows, input_cols, filter_rows, + filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride, + stride, col_buffer_data); + + ConstMatrixMap A(col_buffer_data, output_image_size, filter_total_size); + ConstMatrixMap B(out_backprop_data + output_offset * image_id, + output_image_size, out_depth); + + // Compute gradient with respect to filter. + // TODO(andydavis) Use a multi-threaded matmul implementation here. 
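      // Shapes in this GEMM, accumulated over all images in the batch:
      //   A: [out_rows*out_cols, filter_rows*filter_cols*in_depth]  (im2col of the input)
      //   B: [out_rows*out_cols, out_depth]                         (output gradient)
      //   C: [filter_rows*filter_cols*in_depth, out_depth]          (filter gradient)
      // so C += A^T * B sums an input-patch x output-gradient outer product
      // over every output position, which is exactly dL/dfilter.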
+ C.noalias() += A.transpose() * B; + + input_data += input_offset; + } + } + + private: + std::vector strides_; + Padding padding_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + Conv2DCustomBackpropFilterOp); + +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") + .Device(DEVICE_CPU) + .Label("custom") + .TypeConstraint("T"), + Conv2DCustomBackpropFilterOp); + +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") + .Device(DEVICE_CPU) + .Label("eigen_tensor") + .TypeConstraint("T"), + Conv2DFastBackpropFilterOp); + +// GPU definitions of both ops. +#if GOOGLE_CUDA +namespace { +template +perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory, + uint64 size) { + perftools::gputools::DeviceMemoryBase wrapped(const_cast(cuda_memory), + size * sizeof(T)); + perftools::gputools::DeviceMemory typed(wrapped); + return typed; +} +} // namespace + +// The slow version (but compiles for GPU) + +// Backprop for input. +template +class Conv2DSlowBackpropInputOp : public OpKernel { + public: + explicit Conv2DSlowBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument( + "Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); + use_cudnn_ &= CanUseCudnn(); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_sizes = context->input(0); + const Tensor& filter = context->input(1); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(input_sizes.shape()), + errors::InvalidArgument( + "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", + input_sizes.dims())); + TensorShape input_shape = VectorToShape(input_sizes.vec()); + const TensorShape& filter_shape = filter.shape(); + + EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput"); + Tensor* in_backprop = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &in_backprop)); + + const int padding_rows = + (output_rows - 1) * stride + filter_rows - input_rows; + const int padding_cols = + (output_cols - 1) * stride + filter_cols - input_cols; + + // TODO(keveman): cuDNN only supports equal padding on both sides, so only + // calling it when that is true. Remove this check when (if?) cuDNN starts + // supporting different padding. + bool padding_compatible = + (padding_rows % 2 == 0) && (padding_cols % 2 == 0); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + if (use_cudnn_ && padding_compatible) { + if (filter_rows == 1 && filter_cols == 1 && stride == 1) { + // 1x1 filter, so call cublas directly. 
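        // With a 1x1 filter and stride 1, every pixel is an independent
        // fully-connected layer, so the whole input backprop collapses into a
        // single GEMM over m = batch * rows * cols pixels:
        //   in_backprop[m, in_depth] = out_backprop[m, out_depth] * filter[in_depth, out_depth]^T
        // which is what the ThenBlasGemm call below evaluates (with the
        // operands swapped to account for cuBLAS's column-major convention).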
+ const uint64 m = batch * input_rows * input_cols; + const uint64 k = out_depth; + const uint64 n = in_depth; + + auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), + out_backprop.template flat().size()); + auto b_ptr = AsDeviceMemory(filter.template flat().data(), + filter.template flat().size()); + auto c_ptr = AsDeviceMemory(in_backprop->template flat().data(), + in_backprop->template flat().size()); + + auto transpose = perftools::gputools::blas::Transpose::kTranspose; + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + + bool blas_launch_status = + stream->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, + k, a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", + m, ", n=", n, ", k=", k)); + } + return; + } + + perftools::gputools::dnn::BatchDescriptor input_desc; + input_desc.set_count(batch) + .set_height(input_rows) + .set_width(input_cols) + .set_feature_map_count(in_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc; + output_desc.set_count(batch) + .set_height(output_rows) + .set_width(output_cols) + .set_feature_map_count(out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(filter_rows) + .set_input_filter_width(filter_cols) + .set_input_feature_map_count(in_depth) + .set_output_feature_map_count(out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_filter_stride(stride) + .set_horizontal_filter_stride(stride) + .set_zero_padding_height(padding_rows / 2) + .set_zero_padding_width(padding_cols / 2); + + // NOTE(keveman): + // cuDNN only supports the following layouts : + // Input : B x D x R x C + // Filter : OD x ID x R x C + // Whereas, we have + // Input : B x R x C x D + // Filter : R x C x ID x OD + // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) + // The first TransformDepth performs + // (B x R x C x D) => (B x D x R x C). + // Since the tensor returned from cuDNN is B x D x R x C also, + // the second TransformDepth performs + // (B x D x R x C) => (B x R x C x D). 
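+      // Concretely (assuming the usual Eigen shuffle convention), the shuffle
+      // {0, 3, 1, 2} used below permutes (batch, rows, cols, depth) into
+      // (batch, depth, rows, cols), i.e. NHWC -> NCHW, and the inverse shuffle
+      // {0, 2, 3, 1} applied to the cuDNN result maps it back to NHWC.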
+ Tensor transformed_filter; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({out_depth, in_depth, filter_rows, filter_cols}), + &transformed_filter)); + + functor::TransformFilter()(context->eigen_device(), + filter.tensor(), + transformed_filter.tensor()); + + Tensor transformed_out_backprop; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({batch, out_depth, output_rows, output_cols}), + &transformed_out_backprop)); + + functor::TransformDepth()( + context->eigen_device(), out_backprop.tensor(), + Eigen::DSizes(0, 3, 1, 2), + transformed_out_backprop.tensor()); + + Tensor pre_transformed_in_backprop; + OP_REQUIRES_OK(context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({batch, in_depth, input_rows, input_cols}), + &pre_transformed_in_backprop)); + + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat().data(), + transformed_out_backprop.template flat().size()); + auto filter_ptr = + AsDeviceMemory(transformed_filter.template flat().data(), + transformed_filter.template flat().size()); + auto in_backprop_ptr = + AsDeviceMemory(pre_transformed_in_backprop.template flat().data(), + pre_transformed_in_backprop.template flat().size()); + + bool cudnn_launch_status = + stream->ThenConvolveBackwardData(filter_desc, filter_ptr, output_desc, + out_backprop_ptr, conv_desc, + input_desc, &in_backprop_ptr) + .ok(); + + if (!cudnn_launch_status) { + context->SetStatus(errors::Internal( + "cuDNN Backward Data function launch failure : input shape(", + input_shape.DebugString(), ") filter shape(", + filter_shape.DebugString(), ")")); + } + + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::TransformDepth()( + context->eigen_device(), + toConstTensor(pre_transformed_in_backprop).template tensor(), + Eigen::DSizes(0, 2, 3, 1), + in_backprop->tensor()); + } else { + // We fill out a padded out_backprop + TensorShape padded_out_shape( + {batch, padded_out_rows, padded_out_cols, out_depth}); + Tensor padded_output; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::v(), + padded_out_shape, &padded_output)); + + Eigen::DSizes trivial_order{0, 1, 2, 3}; + Eigen::array, 4> pad_dims{ + {{0, 0}, + {top_pad_rows, bottom_pad_rows}, + {left_pad_cols, right_pad_cols}, + {0, 0}}}; + + functor::InflatePadAndShuffle()( + context->eigen_device(), out_backprop.tensor(), strides, + pad_dims, trivial_order, padded_output.tensor()); + const Tensor& padded_output_cref = padded_output; + + // We then need to fill a new "reverted" filter + // We need to transpose the in_depth and out_depth for the filter and + // inverse the rows and cols. + TensorShape r_filter_shape( + {filter_rows, filter_cols, out_depth, in_depth}); + Tensor r_filter; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::v(), + r_filter_shape, &r_filter)); + + Eigen::DSizes filter_order{0, 1, 3, 2}; + Eigen::array filter_rev_dims{true, true, false, false}; + functor::ShuffleAndReverse()( + context->eigen_device(), filter.tensor(), filter_order, + filter_rev_dims, r_filter.tensor()); + const Tensor& r_filter_cref = r_filter; + + // Now we can call conv_2d directly. 
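+      // This fallback computes the input gradient as a VALID convolution of
+      // the (zero-inflated and padded) out_backprop with r_filter, i.e. the
+      // original filter rotated 180 degrees spatially with its in_depth and
+      // out_depth axes swapped; the result lands directly in in_backprop.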
+ functor::SpatialConvolution()( + context->eigen_device(), in_backprop->tensor(), + padded_output_cref.tensor(), r_filter_cref.tensor(), 1, + BrainPadding2EigenPadding(VALID)); + } + } + + private: + std::vector strides_; + Padding padding_; + bool use_cudnn_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp); +}; + +// Backprop for filter. +template +class Conv2DSlowBackpropFilterOp : public OpKernel { + public: + explicit Conv2DSlowBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument( + "Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); + use_cudnn_ &= CanUseCudnn(); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& filter_sizes = context->input(1); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(filter_sizes.shape()), + errors::InvalidArgument( + "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ", + filter_sizes.dims())); + const TensorShape& input_shape = input.shape(); + TensorShape filter_shape = VectorToShape(filter_sizes.vec()); + + EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropFilter"); + Tensor* filter_backprop = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, filter_shape, &filter_backprop)); + + const int padding_rows = + (output_rows - 1) * stride + filter_rows - input_rows; + const int padding_cols = + (output_cols - 1) * stride + filter_cols - input_cols; + + // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only + // calling it when that is true. Remove this check when (if?) cuDNN starts + // supporting different padding. 
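+    // For example, with SAME padding, a 5x5 input, a 3x3 filter and stride 1:
+    //   padding_rows = padding_cols = (5 - 1) * 1 + 3 - 5 = 2,
+    // which splits evenly into one row/column of zeros on each side and is
+    // therefore cuDNN-compatible; an odd total would require asymmetric
+    // padding and takes the non-cuDNN path below.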
+ bool padding_compatible = + (padding_rows % 2 == 0) && (padding_cols % 2 == 0); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + if (use_cudnn_ && padding_compatible) { + if (filter_rows == 1 && filter_cols == 1 && stride == 1) { + const uint64 m = in_depth; + const uint64 k = batch * input_rows * input_cols; + const uint64 n = out_depth; + + // The shape of output backprop is + // [batch, out_rows, out_cols, out_depth] + // From cublas's perspective, it is: n x k + auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), + out_backprop.template flat().size()); + + // The shape of input is + // [batch, in_rows, in_cols, in_depth], + // From cublas's perspective, it is: m x k + auto b_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + + // the shape of the filter backprop from the conv_2d should be + // [1, 1, in_depth, out_depth] + // From cublas's perspective, it is: n x m + auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), + filter_backprop->template flat().size()); + + bool blas_launch_status = + stream->ThenBlasGemm( + perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose, n, m, k, + 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", + m, ", n=", n, ", k=", k)); + } + return; + } + + perftools::gputools::dnn::BatchDescriptor input_desc; + input_desc.set_count(batch) + .set_height(input_rows) + .set_width(input_cols) + .set_feature_map_count(in_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc; + output_desc.set_count(batch) + .set_height(output_rows) + .set_width(output_cols) + .set_feature_map_count(out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(filter_rows) + .set_input_filter_width(filter_cols) + .set_input_feature_map_count(in_depth) + .set_output_feature_map_count(out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_filter_stride(stride) + .set_horizontal_filter_stride(stride) + .set_zero_padding_height(padding_rows / 2) + .set_zero_padding_width(padding_cols / 2); + + // NOTE(zhengxq): + // cuDNN only supports the following layouts : + // Input : B x D x R x C + // Filter : OD x ID x R x C + // Whereas, we have + // Input : B x R x C x D + // Filter : R x C x ID x OD + // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) + // The first TransformDepth performs + // (B x R x C x D) => (B x D x R x C). + // Since the tensor returned from cuDNN is B x D x R x C also, + // the second TransformDepth performs + // (B x D x R x C) => (B x R x C x D). 
+ + Tensor pre_transformed_filter_backprop; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({out_depth, in_depth, filter_rows, filter_cols}), + &pre_transformed_filter_backprop)); + + Tensor transformed_out_backprop; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({batch, out_depth, output_rows, output_cols}), + &transformed_out_backprop)); + + functor::TransformDepth()( + context->eigen_device(), out_backprop.tensor(), + Eigen::DSizes(0, 3, 1, 2), + transformed_out_backprop.tensor()); + + Tensor transformed_input; + OP_REQUIRES_OK(context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({batch, in_depth, input_rows, input_cols}), + &transformed_input)); + + functor::TransformDepth()( + context->eigen_device(), input.tensor(), + Eigen::DSizes(0, 3, 1, 2), + transformed_input.tensor()); + + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat().data(), + transformed_out_backprop.template flat().size()); + auto filter_backprop_ptr = AsDeviceMemory( + pre_transformed_filter_backprop.template flat().data(), + pre_transformed_filter_backprop.template flat().size()); + auto input_ptr = + AsDeviceMemory(transformed_input.template flat().data(), + transformed_input.template flat().size()); + + bool cudnn_launch_status = + stream->ThenConvolveBackwardFilter(input_desc, input_ptr, output_desc, + out_backprop_ptr, conv_desc, + filter_desc, &filter_backprop_ptr) + .ok(); + + if (!cudnn_launch_status) { + context->SetStatus(errors::Internal( + "cuDNN Backward Filter function launch failure : input shape(", + input_shape.DebugString(), ") filter shape(", + filter_shape.DebugString(), ")")); + } + + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::TransformDepth()( + context->eigen_device(), + toConstTensor(pre_transformed_filter_backprop) + .template tensor(), + Eigen::DSizes(2, 3, 1, 0), + filter_backprop->tensor()); + } else { + // Fall back to the non-cudnn code path + + // For the backprop of the filter, we need to also transpose the + // out_backprop. + // The shape of backprop is + // [batch, out_rows, out_cols, out_depth] + // And we need to change it to + // [out_depth, out_rows, out_cols, batch] + Eigen::DSizes out_order{3, 1, 2, 0}; + TensorShape padded_out_shape( + {out_depth, padded_out_rows, padded_out_cols, batch}); + Tensor padded_output; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::v(), + padded_out_shape, &padded_output)); + + Eigen::array, 4> pad_dims{ + {{0, 0}, + {top_pad_rows, bottom_pad_rows}, + {left_pad_cols, right_pad_cols}, + {0, 0}}}; + functor::InflatePadAndShuffle()( + context->eigen_device(), out_backprop.tensor(), strides, + pad_dims, out_order, padded_output.tensor()); + const Tensor& padded_output_cref = padded_output; + + // For the backprop of the filter, we need to transpose the input. + // The shape of input is + // [batch, in_rows, in_cols, in_depth] + // And we need to change it to + // [in_rows, in_cols, batch, in_depth] + Eigen::DSizes in_order{1, 2, 0, 3}; + TensorShape in_shuffle_shape({input_rows, input_cols, batch, in_depth}); + Tensor in_shuffle; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::v(), + in_shuffle_shape, &in_shuffle)); + + // No need for reversing this time. 
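+      // (No spatial reversal is needed for the input: the filter gradient is
+      // a correlation between the input and out_backprop, and the permutation
+      // chosen above merely lets the batch dimension act as the contraction
+      // depth in the SpatialConvolution call further below.)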
+ Eigen::array trivial_dims{false, false, false, false}; + functor::ShuffleAndReverse()( + context->eigen_device(), input.tensor(), in_order, + trivial_dims, in_shuffle.tensor()); + const Tensor& in_shuffle_cref = in_shuffle; + + // The output of the conv_2d would be + // [out_depth, filter_rows, filter_cols, in_depth] + // and we need to shuffle it back to + // [filter_rows, filter_cols, in_depth, out_depth]; + // And we need to reverse the filter backprops + // So we need to allocated (sigh) yet another piece of memory to hold the + // ouptut. + TensorShape filter_shuffle_shape( + {out_depth, filter_rows, filter_cols, in_depth}); + Tensor filter_shuffle; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), + filter_shuffle_shape, + &filter_shuffle)); + + functor::SpatialConvolution()( + context->eigen_device(), filter_shuffle.tensor(), + padded_output_cref.tensor(), in_shuffle_cref.tensor(), 1, + BrainPadding2EigenPadding(VALID)); + + // Now copy the filter_backprop back to the destination. + Eigen::DSizes filter_order{1, 2, 3, 0}; + Eigen::array filter_rev_dims{true, true, false, false}; + const Tensor& filter_shuffle_cref = filter_shuffle; + functor::ShuffleAndReverse()( + context->eigen_device(), filter_shuffle_cref.tensor(), + filter_order, filter_rev_dims, filter_backprop->tensor()); + } + } + + private: + std::vector strides_; + Padding padding_; + bool use_cudnn_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp); +}; + +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ShuffleAndReverse::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& order, \ + const Eigen::array& reverse_dims, \ + typename TTypes::Tensor output); \ + extern template struct ShuffleAndReverse; \ + template <> \ + void InflatePadAndShuffle::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + const Eigen::array, 4>& pad_dims, \ + const Eigen::DSizes& order, \ + typename TTypes::Tensor output); \ + extern template struct InflatePadAndShuffle; \ + template <> \ + void TransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + extern template struct TransformFilter; \ + template <> \ + void TransformDepth::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const Eigen::DSizes& shuffle, \ + typename TTypes::Tensor out); \ + extern template struct TransformDepth; \ + template <> \ + void SpatialConvolution::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + typename TTypes::ConstTensor filter, int stride, \ + const Eigen::PaddingType& padding); \ + extern template struct SpatialConvolution; \ + template <> \ + void SpatialConvolutionBackwardInput::operator()( \ + const GPUDevice& d, typename TTypes::Tensor in_backprop, \ + typename TTypes::ConstTensor filter, \ + typename TTypes::ConstTensor output_backprop, int input_rows, \ + int input_cols, int stride); \ + extern template struct SpatialConvolutionBackwardInput + +DECLARE_GPU_SPEC(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("input_sizes"), + Conv2DSlowBackpropInputOp); +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") + .Device(DEVICE_GPU) + .TypeConstraint("T") + 
.HostMemory("filter_sizes"), + Conv2DSlowBackpropFilterOp); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc new file mode 100644 index 0000000000..aaa2951778 --- /dev/null +++ b/tensorflow/core/kernels/conv_ops.cc @@ -0,0 +1,373 @@ +// See docs in ../ops/nn_ops.cc. + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/util/use_cudnn.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/public/tensor.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/stream_executor/stream.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct LaunchGeneric { + static void launch(OpKernelContext* ctx, const Tensor& input, + const Tensor& filter, int stride, + const Eigen::PaddingType& padding, Tensor* output) { + if (filter.dim_size(1) == filter.dim_size(0) && filter.dim_size(0) == 1 && + stride == 1) { + // For 1x1 kernel, the 2D convolution is reduced to matrix + // multiplication. + // + // TODO(vrv): We should be able to call SpatialConvolution + // and it will produce the same result, but doing so + // led to NaNs during training. Using matmul instead for now. + int conv_width = 1; // Width for the convolution step. 
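+      // conv_width collapses batch, out_rows and out_cols into a single
+      // dimension, turning the 1x1 convolution into one matmul:
+      //   output [conv_width, out_depth] =
+      //       input [conv_width, in_depth] * filter [in_depth, out_depth].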
+ for (int i = 0; i < 3; ++i) { + conv_width *= output->dim_size(i); + } + + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + functor::MatMulConvFunctor()( + ctx->eigen_device(), + output->shaped({conv_width, filter.dim_size(3)}), + input.shaped({conv_width, filter.dim_size(2)}), + filter.shaped({filter.dim_size(2), filter.dim_size(3)}), + dim_pair); + } else { + functor::SpatialConvolution()( + ctx->eigen_device(), output->tensor(), + input.tensor(), filter.tensor(), stride, padding); + } + } +}; + +template +struct LaunchConvOp; + +template +struct LaunchConvOp { + static void launch(OpKernelContext* ctx, bool use_cudnn, const Tensor& input, + const Tensor& filter, int stride, + const Eigen::PaddingType& padding, Tensor* output) { + LaunchGeneric::launch(ctx, input, filter, stride, padding, + output); + } +}; + +template +class Conv2DOp : public BinaryOp { + public: + explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); + use_cudnn_ &= CanUseCudnn(); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument( + "Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, in_rows, in_cols, in_depth ] + + const Tensor& input = context->input(0); + + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, in_depth, out_depth] + const Tensor& filter = context->input(1); + + // For 2D convolution, there should be 4 dimensions. + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().ShortDebugString())); + OP_REQUIRES(context, filter.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional: ", + filter.shape().ShortDebugString())); + + // The last dimension for input is in_depth. It must be the same as the + // filter's in_depth. + const int64 in_depth = input.dim_size(3); + OP_REQUIRES( + context, in_depth == filter.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + in_depth, " vs ", filter.dim_size(2))); + + // The last dimension for filter is out_depth. + const int64 out_depth = filter.dim_size(3); + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int64 input_rows = input.dim_size(1); + const int64 filter_rows = filter.dim_size(0); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64 input_cols = input.dim_size(2); + const int64 filter_cols = filter.dim_size(1); + + // The first dimension for input is batch. + const int64 batch = input.dim_size(0); + + // For now we take the stride from the second dimension only (we + // assume row = col stride, and do not support striding on the + // batch or depth dimension). 
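+    // For example, a 5x5 input with a 3x3 filter and stride 2 should yield a
+    // 2x2 output under VALID padding and a 3x3 output under SAME padding
+    // (with one implicit row/column of zeros on each side in the SAME case).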
+ const int stride = strides_[1]; + + int out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + if (filter_cols == filter_rows && filter_rows == 1 && stride == 1) { + // For 1x1 kernel, the 2D convolution is reduced to matrix + // multiplication. + out_rows = input_rows; + out_cols = input_cols; + } else { + OP_REQUIRES_OK( + context, Get2dOutputSize(input_rows, input_cols, filter_rows, + filter_cols, stride, stride, padding_, + &out_rows, &out_cols, &pad_rows, &pad_cols)); + } + TensorShape out_shape({batch, out_rows, out_cols, out_depth}); + + // Output tensor is of the following dimensions: + // [ in_batch, out_rows, out_cols, out_depth ] + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + VLOG(2) << "Conv2D: in_depth = " << in_depth + << ", input_cols = " << input_cols + << ", filter_cols = " << filter_cols + << ", input_rows = " << input_rows + << ", filter_rows = " << filter_rows << ", stride = " << stride + << ", out_depth = " << out_depth; + + LaunchConvOp::launch(context, use_cudnn_, input, filter, stride, + BrainPadding2EigenPadding(padding_), + output); + } + + private: + std::vector strides_; + bool use_cudnn_; + Padding padding_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Conv2D") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + Conv2DOp); + +#if GOOGLE_CUDA + +namespace { +template +perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory, + uint64 size) { + perftools::gputools::DeviceMemoryBase wrapped(const_cast(cuda_memory), + size * sizeof(T)); + perftools::gputools::DeviceMemory typed(wrapped); + return typed; +} +} // namespace + +template +struct LaunchConvOp { + static void launch(OpKernelContext* ctx, bool use_cudnn, + const Tensor& input_param, const Tensor& filter, + int stride, const Eigen::PaddingType& padding, + Tensor* output) { + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + + if (use_cudnn) { + Tensor input = input_param; + if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1) { + // 1x1 filter, so call cublas directly. + const uint64 m = + input.dim_size(0) * input.dim_size(1) * input.dim_size(2); + const uint64 k = filter.dim_size(2); + const uint64 n = filter.dim_size(3); + + auto a_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto b_ptr = AsDeviceMemory(filter.template flat().data(), + filter.template flat().size()); + auto c_ptr = AsDeviceMemory(output->template flat().data(), + output->template flat().size()); + + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + bool blas_launch_status = + stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, + b_ptr, n, a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } + return; + } + if (padding == Eigen::PADDING_SAME) { + const int64 out_rows = output->dim_size(1); + const int64 out_cols = output->dim_size(2); + const int64 in_rows = input.dim_size(1); + const int64 in_cols = input.dim_size(2); + const int64 patch_rows = filter.dim_size(0); + const int64 patch_cols = filter.dim_size(1); + // Total padding on rows and cols is + // Pr = (R' - 1) * S + Kr - R + // Pc = (C' - 1) * S + Kc - C + // where (R', C') are output dimensions, (R, C) are input dimensions, S + // is stride, (Kr, Kc) are filter dimensions. 
+ // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top + // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means + // we pad more on the right and bottom than on the top and left. + const int padding_rows = (out_rows - 1) * stride + patch_rows - in_rows; + const int padding_cols = (out_cols - 1) * stride + patch_cols - in_cols; + Tensor transformed_input; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + TensorShape( + {input.dim_size(0), input.dim_size(1) + padding_rows, + input.dim_size(2) + padding_cols, input.dim_size(3)}), + &transformed_input)); + + functor::PadInput()( + ctx->eigen_device(), input_param.tensor(), + padding_rows / 2, padding_rows - padding_rows / 2, padding_cols / 2, + padding_cols - padding_cols / 2, transformed_input.tensor()); + input = transformed_input; + } + + perftools::gputools::dnn::BatchDescriptor input_desc; + input_desc.set_count(input.dim_size(0)) + .set_height(input.dim_size(1)) + .set_width(input.dim_size(2)) + .set_feature_map_count(input.dim_size(3)) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth); + perftools::gputools::dnn::BatchDescriptor output_desc; + output_desc.set_count(output->dim_size(0)) + .set_height(output->dim_size(1)) + .set_width(output->dim_size(2)) + .set_feature_map_count(output->dim_size(3)) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth); + perftools::gputools::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(filter.dim_size(0)) + .set_input_filter_width(filter.dim_size(1)) + .set_input_feature_map_count(filter.dim_size(2)) + .set_output_feature_map_count(filter.dim_size(3)); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_filter_stride(stride) + .set_horizontal_filter_stride(stride); + + Tensor transformed_filter; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp( + DataTypeToEnum::value, + TensorShape({filter.dim_size(3), filter.dim_size(2), + filter.dim_size(0), filter.dim_size(1)}), + &transformed_filter)); + + functor::TransformFilter()( + ctx->eigen_device(), filter.tensor(), + transformed_filter.tensor()); + + auto input_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto filter_ptr = + AsDeviceMemory(transformed_filter.template flat().data(), + transformed_filter.template flat().size()); + auto output_ptr = AsDeviceMemory(output->template flat().data(), + output->template flat().size()); + + bool cudnn_launch_status = + stream->ThenConvolve(input_desc, input_ptr, filter_desc, filter_ptr, + conv_desc, output_desc, &output_ptr) + .ok(); + + if (!cudnn_launch_status) { + ctx->SetStatus(errors::Internal( + "cuDNN launch failure : input shape(", input.shape().DebugString(), + ") filter shape(", filter.shape().DebugString(), ")")); + } + } else { + LaunchGeneric::launch(ctx, input_param, filter, stride, + padding, output); + } + } +}; + +#endif // GOOGLE_CUDA + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. 
+namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void SpatialConvolution::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + typename TTypes::ConstTensor filter, int stride, \ + const Eigen::PaddingType& padding); \ + extern template struct SpatialConvolution; \ + template <> \ + void MatMulConvFunctor::operator()( \ + const GPUDevice& d, typename TTypes::Tensor out, \ + typename TTypes::ConstTensor in0, \ + typename TTypes::ConstTensor in1, \ + const Eigen::array, 1>& dim_pair); \ + extern template struct MatMulConvFunctor; \ + template <> \ + void TransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + extern template struct TransformFilter; \ + template <> \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + int padding_rows_left, int padding_rows_right, int padding_cols_left, \ + int padding_cols_right, typename TTypes::Tensor out); \ + extern template struct PadInput + +DECLARE_GPU_SPEC(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. +REGISTER_KERNEL_BUILDER(Name("Conv2D") + .Device(DEVICE_GPU) + .TypeConstraint("T"), + Conv2DOp); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops_gpu.cu.cc b/tensorflow/core/kernels/conv_ops_gpu.cu.cc new file mode 100644 index 0000000000..44af814e2b --- /dev/null +++ b/tensorflow/core/kernels/conv_ops_gpu.cu.cc @@ -0,0 +1,35 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/conv_2d.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template +struct SpatialConvolution { + void operator()(const GPUDevice& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int stride, + const Eigen::PaddingType& padding) { + // TODO(keveman): nvcc 6.5 crashes when 32 bit indexing is turned on. Enable + // this when we move to cuda 7.0. 
+ // SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input), + // To32Bit(filter), stride, padding); + + SpatialConvolutionFunc(d, output, input, filter, stride, padding); + } +}; + +template struct SpatialConvolution; + +} // end namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc new file mode 100644 index 0000000000..e2e9d25d83 --- /dev/null +++ b/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc @@ -0,0 +1,16 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/conv_2d.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +template struct functor::InflatePadAndShuffle; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc new file mode 100644 index 0000000000..dbbe08ef9c --- /dev/null +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -0,0 +1,22 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/conv_2d.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +template struct functor::ShuffleAndReverse; + +template struct functor::TransformFilter; + +template struct functor::PadInput; + +template struct functor::TransformDepth; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc new file mode 100644 index 0000000000..87d79ecb4d --- /dev/null +++ b/tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc @@ -0,0 +1,16 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/conv_2d.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +template struct functor::MatMulConvFunctor; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/core_ops_test.cc b/tensorflow/core/kernels/core_ops_test.cc new file mode 100644 index 0000000000..a42a5999da --- /dev/null +++ b/tensorflow/core/kernels/core_ops_test.cc @@ -0,0 +1,990 @@ +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include +#include +#include +#include + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include 
"tensorflow/core/public/session.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/port.h" +#include + +namespace tensorflow { + +static void SetConstOp(const string& name, std::initializer_list dims, + NodeDef* node) { + Tensor tensor(DT_FLOAT, TensorShape(dims)); + for (int64 i = 0; i < tensor.NumElements(); ++i) { + tensor.flat()(i) = i / 10.0f; + } + TF_CHECK_OK(NodeDefBuilder(name, "Const") + .Attr("dtype", DT_FLOAT) + .Attr("value", tensor) + .Finalize(node)); +} + +static void SetConstSizesOp(const string& name, const std::vector& sizes, + NodeDef* node) { + TensorShape shape; + shape.AddDim(sizes.size()); + Tensor tensor(DT_INT32, shape); + for (int64 i = 0; i < tensor.NumElements(); ++i) { + tensor.flat()(i) = sizes[i]; + } + TF_CHECK_OK(NodeDefBuilder(name, "Const") + .Attr("dtype", DT_INT32) + .Attr("value", tensor) + .Finalize(node)); +} + +namespace { + +enum CONV_OP { + CONV_OP_FORWARD = 0, + CONV_OP_BACKPROP_INPUT = 1, + CONV_OP_BACKPROP_FILTER = 2 +}; + +} // namespace + +static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth, + int out_depth, int filter_rows, int filter_cols, + CONV_OP op, int num_threads, int stride, + Padding padding, bool use_gpu, const string& label) { + if (!IsGoogleCudaEnabled() && use_gpu) { + testing::SetLabel( + strings::StrCat("Skipping GPU test (no --config=cuda): ", label)); + return; + } + testing::SetLabel(label); + + // Set the number of threads + SessionOptions options; + options.config.set_intra_op_parallelism_threads(num_threads); + + // We set up a graph for computing convolution. + GraphDef graph; + + // For this, we need an input tensor and a filter tensor. + // Compute the output size. + int out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + TF_CHECK_OK(Get2dOutputSize(rows, cols, filter_rows, filter_cols, stride, + stride, padding, &out_rows, &out_cols, &pad_rows, + &pad_cols)); + // Counting the number of floating point operations (both MUL and ADD) + int64 num_ops = 0; + if (op == CONV_OP_FORWARD) { + // Forward computation: + // BATCH x OUT_ROW X OUT_COL X IN_DEPTH X PATCH_ROW X PATH_COL X OUT_DEPTH + // We multiply by two since there are mutliplications and additions. + num_ops = static_cast(batch * in_depth * out_depth) * + static_cast(filter_rows * filter_cols) * + static_cast(out_rows * out_cols) * 2; + } else { + // Backward computation: both input and filter backprop take the same + // amount of computation: + // BATCH x IN_ROW X IN_COL X IN_DEPTH X PATCH_ROW X PATCH_COL X OUT_DEPTH + // We multiply by two since there are mutliplications and additions. + num_ops = static_cast(batch * in_depth * out_depth) * + static_cast(filter_rows * filter_cols) * + static_cast(rows * cols) * 2; + } + + SetConstOp("input", {batch, rows, cols, in_depth}, graph.add_node()); + SetConstOp("filter", {filter_rows, filter_cols, in_depth, out_depth}, + graph.add_node()); + SetConstOp("output_backprop", {batch, out_rows, out_cols, out_depth}, + graph.add_node()); + SetConstSizesOp("input_sizes", + std::vector({batch, rows, cols, in_depth}), + graph.add_node()); + SetConstSizesOp("filter_sizes", std::vector({filter_rows, filter_cols, + in_depth, out_depth}), + graph.add_node()); + + // Now add the convolution op + NodeDef* conv = graph.add_node(); + switch (op) { + case CONV_OP_FORWARD: + TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2D") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? 
"VALID" : "SAME") + .Finalize(conv)); + break; + case CONV_OP_BACKPROP_INPUT: + TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2DBackpropInput") + .Input("input_sizes", 0, DT_INT32) + .Input("filter", 0, DT_FLOAT) + .Input("output_backprop", 0, DT_FLOAT) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? "VALID" : "SAME") + .Finalize(conv)); + break; + case CONV_OP_BACKPROP_FILTER: + TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2DBackpropFilter") + .Input("input", 0, DT_FLOAT) + .Input("filter_sizes", 0, DT_INT32) + .Input("output_backprop", 0, DT_FLOAT) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? "VALID" : "SAME") + .Finalize(conv)); + break; + } + Graph* g = new Graph(OpRegistry::Global()); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g)); + + string device = use_gpu ? "gpu" : "cpu"; + test::Benchmark(device, g, &options).Run(iters); + testing::ItemsProcessed(num_ops * iters); +} + +// BS: batch_size +// R: tensor_in_rows +// C: tensor_in_cols +// ID: input_depth +// OD: output_depth +// KR: kernel_rows +// KC: kernel_cols +#define BM_ConvFloatFwd(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \ + static void BM_ConvFloatFwdCPU1_##LABEL(int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \ + PAD, false, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ + } \ + static void BM_ConvFloatFwdCPU4_##LABEL(int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR, \ + PAD, false, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ + } \ + static void BM_ConvFloatFwdGPU_##LABEL(int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \ + PAD, true, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ + } \ + BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL); \ + BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL); \ + BENCHMARK(BM_ConvFloatFwdGPU_##LABEL) + +BM_ConvFloatFwd(32, 5, 5, 1248, 128, 1, 1, 1, SAME, conv0); +BM_ConvFloatFwd(32, 8, 8, 384, 384, 1, 3, 1, SAME, conv1); +BM_ConvFloatFwd(32, 8, 8, 384, 384, 3, 1, 1, SAME, conv2); +BM_ConvFloatFwd(32, 8, 8, 2048, 192, 1, 1, 1, SAME, conv3); +BM_ConvFloatFwd(32, 8, 8, 448, 384, 3, 3, 1, SAME, conv4); +BM_ConvFloatFwd(32, 8, 8, 2048, 320, 1, 1, 1, SAME, conv5); +BM_ConvFloatFwd(32, 8, 8, 2048, 448, 1, 1, 1, SAME, conv6); +BM_ConvFloatFwd(32, 8, 8, 2048, 384, 1, 1, 1, SAME, conv7); +BM_ConvFloatFwd(32, 8, 8, 1760, 384, 1, 1, 1, SAME, conv8); +BM_ConvFloatFwd(32, 8, 8, 1760, 192, 1, 1, 1, SAME, conv9); +BM_ConvFloatFwd(32, 8, 8, 1760, 448, 1, 1, 1, SAME, conv10); +BM_ConvFloatFwd(32, 8, 8, 1760, 320, 1, 1, 1, SAME, conv11); +BM_ConvFloatFwd(32, 17, 17, 192, 192, 3, 3, 2, VALID, conv12); +BM_ConvFloatFwd(32, 17, 17, 192, 192, 3, 3, 1, SAME, conv13); +BM_ConvFloatFwd(32, 17, 17, 1248, 192, 1, 1, 1, SAME, conv14); +BM_ConvFloatFwd(32, 17, 17, 128, 320, 3, 3, 2, VALID, conv15); +BM_ConvFloatFwd(32, 17, 17, 1248, 128, 1, 1, 1, SAME, conv16); +BM_ConvFloatFwd(32, 17, 17, 224, 224, 1, 3, 1, SAME, conv17); +BM_ConvFloatFwd(32, 17, 17, 192, 256, 3, 1, 1, SAME, conv18); +BM_ConvFloatFwd(32, 17, 17, 192, 256, 1, 3, 1, SAME, conv19); +BM_ConvFloatFwd(32, 17, 17, 1216, 192, 1, 1, 1, SAME, conv20); +BM_ConvFloatFwd(32, 17, 17, 1216, 96, 1, 1, 1, SAME, conv21); +BM_ConvFloatFwd(32, 17, 17, 224, 224, 3, 1, 1, SAME, 
conv22); +BM_ConvFloatFwd(32, 17, 17, 192, 224, 3, 3, 1, SAME, conv23); +BM_ConvFloatFwd(32, 17, 17, 192, 192, 1, 3, 1, SAME, conv24); +BM_ConvFloatFwd(32, 17, 17, 1152, 192, 1, 1, 1, SAME, conv25); +BM_ConvFloatFwd(32, 17, 17, 1152, 128, 1, 1, 1, SAME, conv26); +BM_ConvFloatFwd(32, 17, 17, 192, 192, 3, 1, 1, SAME, conv27); +BM_ConvFloatFwd(32, 17, 17, 160, 192, 3, 3, 1, SAME, conv28); +BM_ConvFloatFwd(32, 17, 17, 1152, 160, 1, 1, 1, SAME, conv29); +BM_ConvFloatFwd(32, 17, 17, 1024, 128, 1, 1, 1, SAME, conv30); +BM_ConvFloatFwd(32, 17, 17, 128, 192, 1, 3, 1, SAME, conv31); +BM_ConvFloatFwd(32, 17, 17, 1024, 160, 1, 1, 1, SAME, conv32); +BM_ConvFloatFwd(32, 17, 17, 128, 192, 3, 1, 1, SAME, conv33); +BM_ConvFloatFwd(32, 17, 17, 1024, 256, 1, 1, 1, SAME, conv34); +BM_ConvFloatFwd(32, 17, 17, 128, 128, 3, 1, 1, SAME, conv35); +BM_ConvFloatFwd(32, 17, 17, 768, 192, 1, 1, 1, SAME, conv36); +BM_ConvFloatFwd(32, 17, 17, 128, 128, 1, 3, 1, SAME, conv37); +BM_ConvFloatFwd(32, 17, 17, 128, 128, 3, 3, 1, SAME, conv38); +BM_ConvFloatFwd(32, 17, 17, 768, 128, 1, 1, 1, SAME, conv39); +BM_ConvFloatFwd(32, 17, 17, 768, 320, 1, 1, 1, SAME, conv40); +BM_ConvFloatFwd(32, 35, 35, 96, 96, 3, 3, 2, VALID, conv41); +BM_ConvFloatFwd(32, 35, 35, 288, 384, 3, 3, 2, VALID, conv42); +BM_ConvFloatFwd(32, 35, 35, 64, 96, 3, 3, 1, SAME, conv43); +BM_ConvFloatFwd(32, 35, 35, 288, 64, 1, 1, 1, SAME, conv44); +BM_ConvFloatFwd(32, 35, 35, 256, 64, 1, 1, 1, SAME, conv45); +BM_ConvFloatFwd(32, 35, 35, 48, 64, 5, 5, 1, SAME, conv46); +BM_ConvFloatFwd(32, 35, 35, 256, 48, 1, 1, 1, SAME, conv47); +BM_ConvFloatFwd(32, 35, 35, 96, 96, 3, 3, 1, SAME, conv48); +BM_ConvFloatFwd(32, 35, 35, 192, 32, 1, 1, 1, SAME, conv49); +BM_ConvFloatFwd(32, 35, 35, 192, 64, 1, 1, 1, SAME, conv50); +BM_ConvFloatFwd(32, 35, 35, 192, 48, 1, 1, 1, SAME, conv51); +BM_ConvFloatFwd(32, 73, 73, 64, 192, 3, 3, 1, VALID, conv52); +BM_ConvFloatFwd(32, 73, 73, 64, 64, 1, 1, 1, VALID, conv53); +BM_ConvFloatFwd(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54); + +#define BM_ConvFloatBkInAndFilter(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \ + static void BM_ConvFloatBkInCPU1_##LABEL(int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \ + STR, PAD, false, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ + } \ + static void BM_ConvFloatBkInCPU4_##LABEL(int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4, \ + STR, PAD, false, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ + } \ + static void BM_ConvFloatBkInGPU_##LABEL(int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \ + STR, PAD, true, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ + } \ + static void BM_ConvFloatBkFilterCPU1_##LABEL(int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ + STR, PAD, false, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ + } \ + static void BM_ConvFloatBkFilterCPU4_##LABEL(int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \ + STR, PAD, false, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ + } \ + static void BM_ConvFloatBkFilterGPU_##LABEL(int iters) { \ + 
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ + STR, PAD, true, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ + } \ + BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL); \ + BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL); \ + BENCHMARK(BM_ConvFloatBkInGPU_##LABEL); \ + BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL); \ + BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL); \ + BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL) + +// Benchmarks from the inception model + +BM_ConvFloatBkInAndFilter(32, 5, 5, 1248, 128, 1, 1, 1, SAME, conv0); +BM_ConvFloatBkInAndFilter(32, 8, 8, 384, 384, 1, 3, 1, SAME, conv1); +BM_ConvFloatBkInAndFilter(32, 8, 8, 384, 384, 3, 1, 1, SAME, conv2); +BM_ConvFloatBkInAndFilter(32, 8, 8, 2048, 192, 1, 1, 1, SAME, conv3); +BM_ConvFloatBkInAndFilter(32, 8, 8, 448, 384, 3, 3, 1, SAME, conv4); +BM_ConvFloatBkInAndFilter(32, 8, 8, 2048, 320, 1, 1, 1, SAME, conv5); +BM_ConvFloatBkInAndFilter(32, 8, 8, 2048, 448, 1, 1, 1, SAME, conv6); +BM_ConvFloatBkInAndFilter(32, 8, 8, 2048, 384, 1, 1, 1, SAME, conv7); +BM_ConvFloatBkInAndFilter(32, 8, 8, 1760, 384, 1, 1, 1, SAME, conv8); +BM_ConvFloatBkInAndFilter(32, 8, 8, 1760, 192, 1, 1, 1, SAME, conv9); +BM_ConvFloatBkInAndFilter(32, 8, 8, 1760, 448, 1, 1, 1, SAME, conv10); +BM_ConvFloatBkInAndFilter(32, 8, 8, 1760, 320, 1, 1, 1, SAME, conv11); +BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 192, 3, 3, 2, VALID, conv12); +BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 192, 3, 3, 1, SAME, conv13); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1248, 192, 1, 1, 1, SAME, conv14); +BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 320, 3, 3, 2, VALID, conv15); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1248, 128, 1, 1, 1, SAME, conv16); +BM_ConvFloatBkInAndFilter(32, 17, 17, 224, 224, 1, 3, 1, SAME, conv17); +BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 256, 3, 1, 1, SAME, conv18); +BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 256, 1, 3, 1, SAME, conv19); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1216, 192, 1, 1, 1, SAME, conv20); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1216, 96, 1, 1, 1, SAME, conv21); +BM_ConvFloatBkInAndFilter(32, 17, 17, 224, 224, 3, 1, 1, SAME, conv22); +BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 224, 3, 3, 1, SAME, conv23); +BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 192, 1, 3, 1, SAME, conv24); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1152, 192, 1, 1, 1, SAME, conv25); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1152, 128, 1, 1, 1, SAME, conv26); +BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 192, 3, 1, 1, SAME, conv27); +BM_ConvFloatBkInAndFilter(32, 17, 17, 160, 192, 3, 3, 1, SAME, conv28); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1152, 160, 1, 1, 1, SAME, conv29); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1024, 128, 1, 1, 1, SAME, conv30); +BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 192, 1, 3, 1, SAME, conv31); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1024, 160, 1, 1, 1, SAME, conv32); +BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 192, 3, 1, 1, SAME, conv33); +BM_ConvFloatBkInAndFilter(32, 17, 17, 1024, 256, 1, 1, 1, SAME, conv34); +BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 128, 3, 1, 1, SAME, conv35); +BM_ConvFloatBkInAndFilter(32, 17, 17, 768, 192, 1, 1, 1, SAME, conv36); +BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 128, 1, 3, 1, SAME, conv37); +BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 128, 3, 3, 1, SAME, conv38); +BM_ConvFloatBkInAndFilter(32, 17, 17, 768, 128, 1, 1, 1, SAME, conv39); +BM_ConvFloatBkInAndFilter(32, 17, 17, 768, 320, 1, 1, 1, SAME, conv40); +BM_ConvFloatBkInAndFilter(32, 35, 35, 96, 
96, 3, 3, 2, VALID, conv41); +BM_ConvFloatBkInAndFilter(32, 35, 35, 288, 384, 3, 3, 2, VALID, conv42); +BM_ConvFloatBkInAndFilter(32, 35, 35, 64, 96, 3, 3, 1, SAME, conv43); +BM_ConvFloatBkInAndFilter(32, 35, 35, 288, 64, 1, 1, 1, SAME, conv44); +BM_ConvFloatBkInAndFilter(32, 35, 35, 256, 64, 1, 1, 1, SAME, conv45); +BM_ConvFloatBkInAndFilter(32, 35, 35, 48, 64, 5, 5, 1, SAME, conv46); +BM_ConvFloatBkInAndFilter(32, 35, 35, 256, 48, 1, 1, 1, SAME, conv47); +BM_ConvFloatBkInAndFilter(32, 35, 35, 96, 96, 3, 3, 1, SAME, conv48); +BM_ConvFloatBkInAndFilter(32, 35, 35, 192, 32, 1, 1, 1, SAME, conv49); +BM_ConvFloatBkInAndFilter(32, 35, 35, 192, 64, 1, 1, 1, SAME, conv50); +BM_ConvFloatBkInAndFilter(32, 35, 35, 192, 48, 1, 1, 1, SAME, conv51); +BM_ConvFloatBkInAndFilter(32, 73, 73, 64, 192, 3, 3, 1, VALID, conv52); +BM_ConvFloatBkInAndFilter(32, 73, 73, 64, 64, 1, 1, 1, VALID, conv53); +BM_ConvFloatBkInAndFilter(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54); + +#define BM_ConvFloatBkFCPU(BS, R, C, ID, OD, KR, KC, TH, LABEL) \ + static void \ + BM_ConvFloatBkFCPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC##_##TH( \ + int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \ + 1, VALID, false, LABEL); \ + } \ + BENCHMARK( \ + BM_ConvFloatBkFCPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC##_##TH) + +// Benchmarks from https://github.com/soumith/convnet-benchmarks +BM_ConvFloatBkFCPU(128, 128, 128, 3, 96, 11, 11, 4, "convnet-layer1"); +BM_ConvFloatBkFCPU(128, 64, 64, 64, 128, 9, 9, 4, "convnet-layer2"); +BM_ConvFloatBkFCPU(128, 32, 32, 128, 128, 9, 9, 4, "convnet-layer3"); +BM_ConvFloatBkFCPU(128, 16, 16, 128, 128, 7, 7, 4, "convnet-layer4"); +BM_ConvFloatBkFCPU(128, 13, 13, 384, 384, 3, 3, 4, "convnet-layer5"); + +#define BM_ConvFloatBkFGPU(BS, R, C, ID, OD, KR, KC, LABEL) \ + static void BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \ + int iters) { \ + BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ + 1, VALID, true, LABEL); \ + } \ + BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC) + +// Benchmarks from https://github.com/soumith/convnet-benchmarks +BM_ConvFloatBkFGPU(128, 128, 128, 3, 96, 11, 11, "convnet-layer1"); +BM_ConvFloatBkFGPU(128, 64, 64, 64, 128, 9, 9, "convnet-layer2"); +BM_ConvFloatBkFGPU(128, 32, 32, 128, 128, 9, 9, "convnet-layer3"); +BM_ConvFloatBkFGPU(128, 16, 16, 128, 128, 7, 7, "convnet-layer4"); +BM_ConvFloatBkFGPU(128, 13, 13, 384, 384, 3, 3, "convnet-layer5"); + +static void BM_LRNFloat(int iters, int depth, int cols, int rows, + int batch_size, int range, int num_threads, + const string& label) { + tensorflow::testing::StopTiming(); + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + thread::ThreadPool threadpool(Env::Default(), "test", num_threads); + EigenThreadPoolWrapper wrapper(&threadpool); + Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + device->set_eigen_cpu_device(&eigen_cpu_device); + + gtl::InlinedVector inputs; + TensorShape shape({batch_size, rows, cols, depth}); + + Tensor input(DT_FLOAT, shape); + test::FillIota(&input, 1.0); + inputs.push_back({nullptr, &input}); + + // Convolution op. 
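+  // (The node built below is actually an LRN op, not a convolution: local
+  // response normalization with depth_radius = range, bias = 1, alpha = 0.1
+  // and beta = 0.5.)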
+ NodeDef lrn_node_def; + TF_CHECK_OK(NodeDefBuilder("lrn_op", "LRN") + .Input("input", 0, DT_FLOAT) + .Attr("depth_radius", range) + .Attr("bias", 1.0) + .Attr("alpha", 0.1) + .Attr("beta", 0.5) + .Finalize(&lrn_node_def)); + + Status status; + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), lrn_node_def, &status)); + TF_CHECK_OK(status); + + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + + std::unique_ptr context(new OpKernelContext(params)); + + op->Compute(context.get()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + delete context->release_output(0).tensor; + op->Compute(context.get()); + } + tensorflow::testing::StopTiming(); + testing::ItemsProcessed(context->mutable_output(0)->NumElements() * iters * + (2 * range + 1) * 2); + testing::SetLabel(label); +} + +#define BM_LRNFloatFwdCPU(DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL) \ + static void \ + BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS( \ + int iters) { \ + BM_LRNFloat(iters, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL); \ + } \ + BENCHMARK( \ + BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS) + +// clang-format off +// DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL +BM_LRNFloatFwdCPU(64, 56, 56, 32, 5, 1, "lrn 1 thread"); +BM_LRNFloatFwdCPU(192, 28, 28, 64, 2, 1, "lrn 1 thread"); +BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 1, "lrn 1 thread"); +BM_LRNFloatFwdCPU(64, 56, 56, 32, 5, 4, "lrn 4 threads"); +BM_LRNFloatFwdCPU(192, 28, 28, 64, 2, 4, "lrn 4 threads"); +BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 4, "lrn 4 threads"); +BM_LRNFloatFwdCPU(64, 56, 56, 32, 5, 8, "lrn 8 threads"); +BM_LRNFloatFwdCPU(192, 28, 28, 64, 2, 8, "lrn 8 threads"); +BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 8, "lrn 8 threads"); +// clang-format on + +/* +AvgPooling Op +*/ +static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth, + int kernel_rows, int kernel_cols, int stride, + Padding padding, int num_threads, const string& label) { + tensorflow::testing::StopTiming(); + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + thread::ThreadPool threadpool(Env::Default(), "test", num_threads); + EigenThreadPoolWrapper wrapper(&threadpool); + Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + device->set_eigen_cpu_device(&eigen_cpu_device); + + gtl::InlinedVector inputs; + TensorShape shape1({batch_size, rows, cols, depth}); + Tensor input1(DT_FLOAT, shape1); + test::FillIota(&input1, 1.0); + inputs.push_back({nullptr, &input1}); + + // AvgPooling op. + NodeDef avgpool_node_def; + CHECK_EQ(kernel_rows, kernel_cols); + Status status = NodeDefBuilder("avgpool_op", "AvgPool") + .Input(FakeInput(DT_FLOAT)) + .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? 
"VALID" : "SAME") + .Finalize(&avgpool_node_def); + TF_CHECK_OK(status); + + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), avgpool_node_def, &status)); + TF_CHECK_OK(status); + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + + std::unique_ptr avgpool_context(new OpKernelContext(params)); + + op->Compute(avgpool_context.get()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + delete avgpool_context->release_output(0).tensor; + op->Compute(avgpool_context.get()); + } + tensorflow::testing::StopTiming(); + testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() * + iters); + testing::SetLabel(label); +} + +// BS: batch_size +// IR: input_rows +// IC: input_cols +// ND: node_depth +// KR: kernel_rows +// KC: kernel_cols +// ST: stride. We use the same stride for both directions. +// PT: padding +#define BM_AvgPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ + static void \ + BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \ + int iters) { \ + BM_AvgPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ + } \ + BENCHMARK( \ + BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) + +// Labels are taken from the 2014-July-24 version of imagenet +BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 1, "avgpool0_VALID"); +BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, VALID, 1, "avgpool1_VALID"); +BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, VALID, 1, "avgpool4_VALID"); +BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, VALID, 1, "avgpool10_VALID"); +BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, SAME, 1, "avgpool0_SAME"); +BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 1, "avgpool1_SAME"); +BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 1, "avgpool4_SAME"); +BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 1, "avgpool10_SAME"); +BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 4, "avgpool0_VALID"); +BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, VALID, 4, "avgpool1_VALID"); +BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, VALID, 4, "avgpool4_VALID"); +BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, VALID, 4, "avgpool10_VALID"); +BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, SAME, 4, "avgpool0_SAME"); +BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "avgpool1_SAME"); +BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "avgpool4_SAME"); +BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "avgpool10_SAME"); + +static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols, + int depth, int kernel_rows, int kernel_cols, + int stride, Padding padding, int num_threads, + const string& label) { + tensorflow::testing::StopTiming(); + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + thread::ThreadPool threadpool(Env::Default(), "test", num_threads); + EigenThreadPoolWrapper wrapper(&threadpool); + Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + device->set_eigen_cpu_device(&eigen_cpu_device); + + gtl::InlinedVector inputs; + + int out_height, out_width, pad_rows, pad_cols; + Status status = + Get2dOutputSize(rows, cols, kernel_rows, kernel_cols, stride, stride, + padding, &out_height, &out_width, 
&pad_rows, &pad_cols); + TF_CHECK_OK(status); + TensorShape output_shape({batch_size, out_height, out_width, depth}); + TensorShape shape2({4}); + Tensor input_shape_tensor(DT_INT32, shape2); + int32 input_dims[] = {batch_size, rows, cols, depth}; + for (int i = 0; i < 4; i++) { + input_shape_tensor.flat()(i) = input_dims[i]; + } + inputs.push_back({nullptr, &input_shape_tensor}); + + Tensor output_backprop(DT_FLOAT, output_shape); + test::FillIota(&output_backprop, 11.0); + inputs.push_back({nullptr, &output_backprop}); + + // AvgPoolGrad op. + NodeDef avgpool_grad_node_def; + status = NodeDefBuilder("avgpool_grad_op", "AvgPoolGrad") + .Input(FakeInput()) + .Input(FakeInput(DT_FLOAT)) + .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? "VALID" : "SAME") + .Finalize(&avgpool_grad_node_def); + TF_CHECK_OK(status); + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, nullptr, cpu_allocator(), avgpool_grad_node_def, &status)); + TF_CHECK_OK(status); + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + + std::unique_ptr avgpool_context(new OpKernelContext(params)); + + op->Compute(avgpool_context.get()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + delete avgpool_context->release_output(0).tensor; + op->Compute(avgpool_context.get()); + } + tensorflow::testing::StopTiming(); + testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() * + iters); + testing::SetLabel(label); +} + +// BS: batch_size +// IR: input_rows +// IC: input_cols +// ND: node_depth +// KR: kernel_rows +// KC: kernel_cols +// ST: stride. We use the same stride for both directions. +// PT: padding +// The resulted symbol is too long. 
Need to use two macros to fit in 80-chars +#define BM_AvgPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ + static void \ + BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \ + int iters) { \ + BM_AvgPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ + } \ + BENCHMARK( \ + BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) + +// Shapes taken from the 2015/05/16 inception model +BM_AvgPoolBkCPU(32, 35, 35, 192, 3, 3, 1, SAME, 1, "avgpool_grad0_SAME"); +BM_AvgPoolBkCPU(32, 35, 35, 256, 3, 3, 1, SAME, 1, "avgpool_grad1_SAME"); +BM_AvgPoolBkCPU(32, 17, 17, 768, 3, 3, 1, SAME, 1, "avgpool_grad2_SAME"); +BM_AvgPoolBkCPU(32, 17, 17, 1024, 3, 3, 1, SAME, 1, "avgpool_grad3_SAME"); +BM_AvgPoolBkCPU(32, 17, 17, 1152, 3, 3, 1, SAME, 1, "avgpool_grad4_SAME"); +BM_AvgPoolBkCPU(32, 17, 17, 1216, 3, 3, 1, SAME, 1, "avgpool_grad5_SAME"); +BM_AvgPoolBkCPU(32, 17, 17, 1248, 5, 5, 3, VALID, 1, "avgpool_grad6_VALID"); +BM_AvgPoolBkCPU(32, 8, 8, 1760, 3, 3, 1, SAME, 1, "avgpool_grad7_SAME"); +BM_AvgPoolBkCPU(32, 8, 8, 2048, 8, 8, 1, VALID, 1, "avgpool_grad8_VALID"); + +/* +MaxPooling Op +*/ +static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth, + int kernel_rows, int kernel_cols, int stride, + Padding padding, int num_threads, const string& label) { + tensorflow::testing::StopTiming(); + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + thread::ThreadPool threadpool(Env::Default(), "test", num_threads); + EigenThreadPoolWrapper wrapper(&threadpool); + Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + device->set_eigen_cpu_device(&eigen_cpu_device); + + gtl::InlinedVector inputs; + TensorShape shape1({batch_size, rows, cols, depth}); + Tensor input1(DT_FLOAT, shape1); + test::FillIota(&input1, 1.0); + inputs.push_back({nullptr, &input1}); + + // MaxPooling op. + NodeDef maxpool_node_def; + CHECK_EQ(kernel_rows, kernel_cols); + Status status = NodeDefBuilder("maxpool_op", "MaxPool") + .Input(FakeInput()) + .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? "VALID" : "SAME") + .Finalize(&maxpool_node_def); + TF_CHECK_OK(status); + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), maxpool_node_def, &status)); + TF_CHECK_OK(status); + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + + std::unique_ptr maxpool_context(new OpKernelContext(params)); + + op->Compute(maxpool_context.get()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + delete maxpool_context->release_output(0).tensor; + op->Compute(maxpool_context.get()); + } + tensorflow::testing::StopTiming(); + testing::ItemsProcessed(maxpool_context->mutable_output(0)->NumElements() * + iters); + testing::SetLabel(label); +} + +// BS: batch_size +// IR: input_rows +// IC: input_cols +// ND: node_depth +// KR: kernel_rows +// KC: kernel_cols +// ST: stride. We use the same stride for both directions. 
+// PT: padding +#define BM_MaxPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ + static void \ + BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \ + int iters) { \ + BM_MaxPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ + } \ + BENCHMARK( \ + BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) + +// Labels are taken from the 2014-July-24 version of imagenet +BM_MaxPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 1, "maxpool0_VALID"); +BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, VALID, 1, "maxpool1_VALID"); +BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, VALID, 1, "maxpool4_VALID"); +BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, VALID, 1, "maxpool10_VALID"); +BM_MaxPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, SAME, 1, "maxpool0_SAME"); +BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 1, "maxpool1_SAME"); +BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 1, "maxpool4_SAME"); +BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 1, "maxpool10_SAME"); +BM_MaxPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 4, "maxpool0_VALID"); +BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, VALID, 4, "maxpool1_VALID"); +BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, VALID, 4, "maxpool4_VALID"); +BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, VALID, 4, "maxpool10_VALID"); +BM_MaxPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, SAME, 4, "maxpool0_SAME"); +BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "maxpool1_SAME"); +BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "maxpool4_SAME"); +BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "maxpool10_SAME"); + +static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols, + int depth, int kernel_rows, int kernel_cols, + int stride, Padding padding, int num_threads, + bool use_gpu, const string& label) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + + int out_height, out_width, pad_rows, pad_cols; + Status status = + Get2dOutputSize(rows, cols, kernel_rows, kernel_cols, stride, stride, + padding, &out_height, &out_width, &pad_rows, &pad_cols); + TF_CHECK_OK(status); + + Tensor input_data(DT_FLOAT, TensorShape({batch_size, rows, cols, depth})); + input_data.flat().setRandom(); + Node* input_data_node = ops::Const(input_data, b.opts()); + + Tensor output_data(DT_FLOAT, + TensorShape({batch_size, out_height, out_width, depth})); + output_data.flat().setRandom(); + Node* output_data_node = ops::Const(output_data, b.opts()); + + Tensor output_diff(DT_FLOAT, + TensorShape({batch_size, out_height, out_width, depth})); + output_diff.flat().setRandom(); + Node* output_diff_node = ops::Const(output_diff, b.opts()); + + CHECK_EQ(kernel_rows, kernel_cols); + ops::MaxPoolGrad(input_data_node, output_data_node, output_diff_node, + {1, kernel_rows, kernel_cols, 1} /* ksize */, + {1, stride, stride, 1} /* stride */, + padding == VALID ? "VALID" : "SAME", b.opts()); + Graph* g = new Graph(OpRegistry::Global()); + TF_CHECK_OK(b.ToGraph(g)); + string device = use_gpu ? "gpu" : "cpu"; + test::Benchmark(device, g).Run(iters); + + testing::ItemsProcessed(batch_size * rows * cols * depth * iters); + testing::SetLabel(label); +} + +// BS: batch_size +// IR: input_rows +// IC: input_cols +// ND: node_depth +// KR: kernel_rows +// KC: kernel_cols +// ST: stride. We use the same stride for both directions. +// PT: padding +// The resulted symbol is too long. 
Need to use two macros to fit in 80-chars +// clang-format off +#define BM_MaxPoolBkGPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ + static void \ + BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \ + ##PT##_##TH( \ + int iters) { \ + BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL); \ + } \ + BENCHMARK( \ + BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \ + ##PT##_##TH) \ + +#define BM_MaxPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ + static void \ + BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \ + ##PT##_##TH( \ + int iters) { \ + BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL); \ + } \ + BENCHMARK( \ + BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \ + ##PT##_##TH) +// clang-format on + +// Shapes taken from the 2015/05/16 inception model +BM_MaxPoolBkGPU(32, 147, 147, 64, 3, 3, 2, VALID, 1, "maxpool_grad0_VALID"); +BM_MaxPoolBkGPU(32, 71, 71, 192, 3, 3, 2, VALID, 1, "maxpool_grad1_VALID"); +BM_MaxPoolBkGPU(32, 35, 35, 288, 3, 3, 2, VALID, 1, "maxpool_grad2_VALID"); +BM_MaxPoolBkGPU(32, 17, 17, 1248, 3, 3, 2, VALID, 1, "maxpool_grad3_VALID"); +BM_MaxPoolBkGPU(32, 8, 8, 2048, 3, 3, 2, VALID, 1, "maxpool_grad4_VALID"); + +BM_MaxPoolBkCPU(32, 147, 147, 64, 3, 3, 2, VALID, 1, "maxpool_grad0_VALID"); +BM_MaxPoolBkCPU(32, 71, 71, 192, 3, 3, 2, VALID, 1, "maxpool_grad1_VALID"); +BM_MaxPoolBkCPU(32, 35, 35, 288, 3, 3, 2, VALID, 1, "maxpool_grad2_VALID"); +BM_MaxPoolBkCPU(32, 17, 17, 1248, 3, 3, 2, VALID, 1, "maxpool_grad3_VALID"); +BM_MaxPoolBkCPU(32, 8, 8, 2048, 3, 3, 2, VALID, 1, "maxpool_grad4_VALID"); + +/* +Relu Op +Run benchmark with: +*/ +static void BM_ReluFloat(int iters, int batch_size, int rows, int cols, + int depth, int num_threads, const string& label) { + tensorflow::testing::StopTiming(); + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + thread::ThreadPool threadpool(Env::Default(), "test", num_threads); + EigenThreadPoolWrapper wrapper(&threadpool); + Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + device->set_eigen_cpu_device(&eigen_cpu_device); + + gtl::InlinedVector inputs; + TensorShape shape1({batch_size, rows, cols, depth}); + Tensor input1(DT_FLOAT, shape1); + test::FillIota(&input1, 1.0); + inputs.push_back({nullptr, &input1}); + + // Reluing op. 
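+  // The pattern below mirrors the other forward benchmarks in this file:
+  // build the NodeDef, instantiate the CPU kernel, run one warm-up Compute(),
+  // then time a loop that frees the previous output tensor before each
+  // Compute() so every timed iteration allocates a fresh output.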
+ NodeDef relu_node_def; + Status status = NodeDefBuilder("relu_op", "Relu") + .Input(FakeInput(DT_FLOAT)) + .Finalize(&relu_node_def); + TF_CHECK_OK(status); + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), relu_node_def, &status)); + TF_CHECK_OK(status); + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + + std::unique_ptr relu_context(new OpKernelContext(params)); + + op->Compute(relu_context.get()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + delete relu_context->release_output(0).tensor; + op->Compute(relu_context.get()); + } + tensorflow::testing::StopTiming(); + testing::ItemsProcessed(relu_context->mutable_output(0)->NumElements() * + iters); + testing::SetLabel(label); +} + +// BS: batch_size +// IR: input_rows +// IC: input_cols +// ND: node_depth +#define BM_Relu(BS, IR, IC, ND, TH, LABEL) \ + static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \ + BM_ReluFloat(iters, BS, IR, IC, ND, TH, LABEL); \ + } \ + BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH) + +BM_Relu(32, 112, 112, 64, 1, "relu0"); +BM_Relu(32, 56, 56, 192, 1, "relu1"); +BM_Relu(32, 28, 28, 352, 1, "relu4"); +BM_Relu(32, 14, 14, 576, 1, "relu10"); +BM_Relu(32, 112, 112, 64, 4, "relu0"); +BM_Relu(32, 56, 56, 192, 4, "relu1"); +BM_Relu(32, 28, 28, 352, 4, "relu4"); +BM_Relu(32, 14, 14, 576, 4, "relu10"); + +static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth, + int num_threads, const string& label) { + tensorflow::testing::StopTiming(); + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + thread::ThreadPool threadpool(Env::Default(), "test", num_threads); + EigenThreadPoolWrapper wrapper(&threadpool); + Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + device->set_eigen_cpu_device(&eigen_cpu_device); + + gtl::InlinedVector inputs; + TensorShape shape1({node_depth, batch_size}); + Tensor* input1 = new Tensor(DT_FLOAT, shape1); + test::FillIota(input1, 1.0); + inputs.push_back({nullptr, input1}); + + // Softmax op. 
+ NodeDef softmax_node_def; + TF_CHECK_OK(NodeDefBuilder("softmax_op", "Softmax") + .Input("input", 0, DT_FLOAT) + .Finalize(&softmax_node_def)); + Status status; + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), softmax_node_def, &status)); + TF_CHECK_OK(status); + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + + std::unique_ptr softmax_context(new OpKernelContext(params)); + + op->Compute(softmax_context.get()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + delete softmax_context->release_output(0).tensor; + op->Compute(softmax_context.get()); + } + tensorflow::testing::StopTiming(); + testing::ItemsProcessed(softmax_context->mutable_output(0)->NumElements() * + iters); + testing::SetLabel(label); +} + +#define BM_ImageNetSoftmaxFwdCPU(BATCH_SIZE, NODE_DEPTH, TH, LABEL) \ + static void BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH( \ + int iters) { \ + BM_ImageNetSoftmaxFwd(iters, BATCH_SIZE, NODE_DEPTH, TH, LABEL); \ + } \ + BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH) + +// Labels are taken from the 2014-July-24 version of imagenet +BM_ImageNetSoftmaxFwdCPU(32, 1008, 1, "softmax32"); +BM_ImageNetSoftmaxFwdCPU(128, 1008, 1, "softmax128"); +BM_ImageNetSoftmaxFwdCPU(32, 1008, 4, "softmax32"); +BM_ImageNetSoftmaxFwdCPU(128, 1008, 4, "softmax128"); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/count_up_to_op.cc b/tensorflow/core/kernels/count_up_to_op.cc new file mode 100644 index 0000000000..7cf4bdb6d0 --- /dev/null +++ b/tensorflow/core/kernels/count_up_to_op.cc @@ -0,0 +1,51 @@ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +template +class CountUpToOp : public OpKernel { + public: + explicit CountUpToOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("limit", &limit_)); + } + + void Compute(OpKernelContext* context) override { + T before_increment; + { + mutex_lock l(*context->input_ref_mutex(0)); + Tensor tensor = context->mutable_input(0, true); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(tensor.shape()), + errors::InvalidArgument("input is not a scalar: ", + tensor.shape().DebugString())); + T* ptr = &tensor.scalar()(); + before_increment = *ptr; + if (*ptr >= limit_) { + context->SetStatus(errors::OutOfRange("Reached limit of ", limit_)); + return; + } + ++*ptr; + } + // Output if no error. 
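+    // The op outputs the value observed *before* the increment, so with
+    // limit = 3 and a ref variable starting at 0, successive calls produce
+    // 0, 1, 2, and the fourth call fails with OutOfRange("Reached limit of 3").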
+ Tensor* out_tensor; + OP_REQUIRES_OK(context, context->allocate_output("output", TensorShape({}), + &out_tensor)); + out_tensor->scalar()() = before_increment; + } + + private: + T limit_; +}; + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("CountUpTo").TypeConstraint("T").Device(DEVICE_CPU), \ + CountUpToOp) + +REGISTER(int32); +REGISTER(int64); + +#undef REGISTER + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc new file mode 100644 index 0000000000..5d39b88166 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_abs.cc @@ -0,0 +1,23 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER4(UnaryOp, CPU, "Abs", functor::abs, float, double, int32, int64); +#ifndef __ANDROID__ +REGISTER_KERNEL_BUILDER(Name("ComplexAbs").Device(DEVICE_CPU), + UnaryOp>); +#endif +#if GOOGLE_CUDA +REGISTER3(UnaryOp, GPU, "Abs", functor::abs, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Abs") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .TypeConstraint("T"), + UnaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_add.cc b/tensorflow/core/kernels/cwise_op_add.cc new file mode 100644 index 0000000000..a6cd4bddbe --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_add.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER7(BinaryOp, CPU, "Add", functor::add, float, double, int32, int64, int8, + int16, complex64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Add", functor::add, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
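+// Keeping int32 in host memory is usually cheap: in practice such tensors
+// tend to be small shape or index values, so a host-resident kernel avoids a
+// device round trip for them.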
+REGISTER_KERNEL_BUILDER(Name("Add") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_ceil.cc b/tensorflow/core/kernels/cwise_op_ceil.cc new file mode 100644 index 0000000000..0a8f1313f8 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_ceil.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(UnaryOp, CPU, "Ceil", functor::ceil, float, double); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Ceil", functor::ceil, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_complex.cc b/tensorflow/core/kernels/cwise_op_complex.cc new file mode 100644 index 0000000000..825181bc35 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_complex.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER_KERNEL_BUILDER(Name("Complex").Device(DEVICE_CPU), + BinaryOp>); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("Complex").Device(DEVICE_GPU), + BinaryOp>); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_conj.cc b/tensorflow/core/kernels/cwise_op_conj.cc new file mode 100644 index 0000000000..ba445d1c3d --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_conj.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER_KERNEL_BUILDER(Name("Conj").Device(DEVICE_CPU), + UnaryOp>); +#if GOOGLE_CUDA +// REGISTER_KERNEL_BUILDER(Name("Conj").Device(DEVICE_GPU), +// UnaryOp>); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_cos.cc b/tensorflow/core/kernels/cwise_op_cos.cc new file mode 100644 index 0000000000..45e24fc2ec --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_cos.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Cos", functor::cos, float, double, complex64); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Cos", functor::cos, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc new file mode 100644 index 0000000000..76d606ed03 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_div.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Div", functor::div, float, double, int32, int64, + complex64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Div", functor::div, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
+REGISTER_KERNEL_BUILDER(Name("Div") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_equal_to.cc b/tensorflow/core/kernels/cwise_op_equal_to.cc new file mode 100644 index 0000000000..8369299332 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_equal_to.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Equal", functor::equal_to, float, double, int32, + int64, complex64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Equal", functor::equal_to, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Equal") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_exp.cc b/tensorflow/core/kernels/cwise_op_exp.cc new file mode 100644 index 0000000000..b2603a1b4c --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_exp.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Exp", functor::exp, float, double, complex64); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Exp", functor::exp, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_floor.cc b/tensorflow/core/kernels/cwise_op_floor.cc new file mode 100644 index 0000000000..83c8203953 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_floor.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(UnaryOp, CPU, "Floor", functor::floor, float, double); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Floor", functor::floor, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc new file mode 100644 index 0000000000..59436afbc0 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY3(abs, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc new file mode 100644 index 0000000000..edf8e0d1a5 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(add, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc new file mode 100644 index 0000000000..f24c4b8b73 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(ceil, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git 
a/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc new file mode 100644 index 0000000000..29086b5c71 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY1(make_complex, float); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc new file mode 100644 index 0000000000..cae22cea8e --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +// DEFINE_UNARY1(conj, complex64); // not working +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc new file mode 100644 index 0000000000..c8412496a8 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(cos, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc new file mode 100644 index 0000000000..c581c0487e --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(div, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc new file mode 100644 index 0000000000..f994822a74 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY4(equal_to, float, double, int64, complex64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc new file mode 100644 index 0000000000..caeaa19cef --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(exp, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc new file mode 100644 index 0000000000..0a06ff2978 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(floor, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc 
b/tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc new file mode 100644 index 0000000000..e1278e077b --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(greater, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_greater_equal.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_greater_equal.cu.cc new file mode 100644 index 0000000000..fafcf9b28a --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_greater_equal.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(greater_equal, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc new file mode 100644 index 0000000000..0370782c96 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY1(get_imag, complex64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc new file mode 100644 index 0000000000..020abef210 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY3(inverse, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_isfinite.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_isfinite.cu.cc new file mode 100644 index 0000000000..7a3a273af7 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_isfinite.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(isfinite, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc new file mode 100644 index 0000000000..cfc4be3d25 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(isinf, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc new file mode 100644 index 0000000000..c93b74387e --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(isnan, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_less.cu.cc 
b/tensorflow/core/kernels/cwise_op_gpu_less.cu.cc new file mode 100644 index 0000000000..8e2b28ac60 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_less.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(less, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_less_equal.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_less_equal.cu.cc new file mode 100644 index 0000000000..be8e34a58b --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_less_equal.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(less_equal, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_log.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_log.cu.cc new file mode 100644 index 0000000000..7d183cce50 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_log.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(log, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_logical_and.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_logical_and.cu.cc new file mode 100644 index 0000000000..ba7046f9f0 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_logical_and.cu.cc @@ -0,0 +1,13 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +template struct BinaryFunctor; +template struct BinaryFunctor; +template struct BinaryFunctor; +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_logical_not.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_logical_not.cu.cc new file mode 100644 index 0000000000..34a43a76ef --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_logical_not.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +template struct UnaryFunctor; +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_logical_or.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_logical_or.cu.cc new file mode 100644 index 0000000000..47a7bd68dc --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_logical_or.cu.cc @@ -0,0 +1,13 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +template struct BinaryFunctor; +template struct BinaryFunctor; +template struct BinaryFunctor; +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc new file mode 100644 index 0000000000..8f7ab90e9a --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(maximum, float, double, int64); +} // namespace functor +} // namespace tensorflow + 
+#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc new file mode 100644 index 0000000000..75fd7f89b4 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(minimum, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc new file mode 100644 index 0000000000..d08a17a94d --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +// No GPU ops for mod yet. +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc new file mode 100644 index 0000000000..e0a6738bef --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(mul, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc new file mode 100644 index 0000000000..3031afbb75 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY4(neg, float, double, int32, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc new file mode 100644 index 0000000000..59c76ee88b --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY4(not_equal_to, float, double, int64, complex64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc new file mode 100644 index 0000000000..50177495bc --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(pow, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc new file mode 100644 index 0000000000..3b1d465914 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY1(get_real, complex64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git 
a/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc new file mode 100644 index 0000000000..682e2d2d4b --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(rsqrt, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc new file mode 100644 index 0000000000..b5125648e3 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc @@ -0,0 +1,15 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +template struct SelectFunctor; +template struct SelectFunctor; +template struct SelectFunctor; +template struct SelectFunctor; +template struct SelectFunctor; +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc new file mode 100644 index 0000000000..9c250f3071 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(sigmoid, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc new file mode 100644 index 0000000000..f413480ecc --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY3(sign, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc new file mode 100644 index 0000000000..6135f3b780 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(sin, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc new file mode 100644 index 0000000000..9bdf3b9e30 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(sqrt, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_square.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_square.cu.cc new file mode 100644 index 0000000000..6b900e994d --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_square.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY3(square, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // 
GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc new file mode 100644 index 0000000000..6fd5ea0d38 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY3(sub, float, double, int64); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc new file mode 100644 index 0000000000..e0393f6c2a --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc @@ -0,0 +1,11 @@ +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY2(tanh, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc new file mode 100644 index 0000000000..9ae31dcdfe --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_greater.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER4(BinaryOp, CPU, "Greater", functor::greater, float, double, int32, + int64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Greater", functor::greater, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Greater") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc new file mode 100644 index 0000000000..be4cc5dc79 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER4(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float, double, + int32, int64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float, double, + int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
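+// Same host-memory treatment as the arithmetic ops above; the comparison
+// kernels differ only in that their output "z" is a bool tensor.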
+REGISTER_KERNEL_BUILDER(Name("GreaterEqual") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_imag.cc b/tensorflow/core/kernels/cwise_op_imag.cc new file mode 100644 index 0000000000..c2432326fc --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_imag.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER_KERNEL_BUILDER(Name("Imag").Device(DEVICE_CPU), + UnaryOp>); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("Imag").Device(DEVICE_GPU), + UnaryOp>); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_inverse.cc b/tensorflow/core/kernels/cwise_op_inverse.cc new file mode 100644 index 0000000000..6af883e755 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_inverse.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Inv", functor::inverse, float, double, complex64); +#if GOOGLE_CUDA +REGISTER3(UnaryOp, GPU, "Inv", functor::inverse, float, double, int64); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_isfinite.cc b/tensorflow/core/kernels/cwise_op_isfinite.cc new file mode 100644 index 0000000000..e52d199a8f --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_isfinite.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(UnaryOp, CPU, "IsFinite", functor::isfinite, float, double); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "IsFinite", functor::isfinite, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_isinf.cc b/tensorflow/core/kernels/cwise_op_isinf.cc new file mode 100644 index 0000000000..868204f86e --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_isinf.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(UnaryOp, CPU, "IsInf", functor::isinf, float, double); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "IsInf", functor::isinf, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_isnan.cc b/tensorflow/core/kernels/cwise_op_isnan.cc new file mode 100644 index 0000000000..a8f4d60d0f --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_isnan.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(UnaryOp, CPU, "IsNan", functor::isnan, float, double); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "IsNan", functor::isnan, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc new file mode 100644 index 0000000000..3b5f75445c --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_less.cc @@ -0,0 +1,20 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER4(BinaryOp, CPU, "Less", functor::less, float, double, int32, int64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Less", functor::less, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
+REGISTER_KERNEL_BUILDER(Name("Less") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc new file mode 100644 index 0000000000..507c7c2908 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_less_equal.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER4(BinaryOp, CPU, "LessEqual", functor::less_equal, float, double, int32, + int64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "LessEqual", functor::less_equal, float, double, + int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("LessEqual") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc new file mode 100644 index 0000000000..ebc7cbcc4e --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_log.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Log", functor::log, float, double, complex64); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Log", functor::log, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_logical_and.cc b/tensorflow/core/kernels/cwise_op_logical_and.cc new file mode 100644 index 0000000000..a4075088f4 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_logical_and.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER_KERNEL_BUILDER(Name("LogicalAnd").Device(DEVICE_CPU), + BinaryOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("LogicalAnd").Device(DEVICE_GPU), + BinaryOp); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_logical_not.cc b/tensorflow/core/kernels/cwise_op_logical_not.cc new file mode 100644 index 0000000000..b2e97bf70c --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_logical_not.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER_KERNEL_BUILDER(Name("LogicalNot").Device(DEVICE_CPU), + UnaryOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("LogicalNot").Device(DEVICE_GPU), + UnaryOp); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_logical_or.cc b/tensorflow/core/kernels/cwise_op_logical_or.cc new file mode 100644 index 0000000000..0d1df082f7 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_logical_or.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER_KERNEL_BUILDER(Name("LogicalOr").Device(DEVICE_CPU), + BinaryOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("LogicalOr").Device(DEVICE_GPU), + BinaryOp); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_maximum.cc b/tensorflow/core/kernels/cwise_op_maximum.cc new file mode 100644 index 0000000000..c0c9e3f6f5 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_maximum.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER4(BinaryOp, CPU, "Maximum", functor::maximum, float, double, 
int32, + int64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Maximum", functor::maximum, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Maximum") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_minimum.cc b/tensorflow/core/kernels/cwise_op_minimum.cc new file mode 100644 index 0000000000..4c6bf7df05 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_minimum.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER4(BinaryOp, CPU, "Minimum", functor::minimum, float, double, int32, + int64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Minimum", functor::minimum, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Minimum") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_mod.cc b/tensorflow/core/kernels/cwise_op_mod.cc new file mode 100644 index 0000000000..17f2834030 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_mod.cc @@ -0,0 +1,6 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(BinaryOp, CPU, "Mod", functor::mod, int32, int64); +REGISTER2(BinaryOp, CPU, "Mod", functor::fmod, float, double); +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_mul.cc b/tensorflow/core/kernels/cwise_op_mul.cc new file mode 100644 index 0000000000..15f65012cd --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_mul.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER7(BinaryOp, CPU, "Mul", functor::mul, float, double, int32, int64, int8, + int16, complex64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Mul", functor::mul, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
+REGISTER_KERNEL_BUILDER(Name("Mul") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_neg.cc b/tensorflow/core/kernels/cwise_op_neg.cc new file mode 100644 index 0000000000..3a19b2e94f --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_neg.cc @@ -0,0 +1,9 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(UnaryOp, CPU, "Neg", functor::neg, float, double, int32, complex64, + int64); +#if GOOGLE_CUDA +REGISTER4(UnaryOp, GPU, "Neg", functor::neg, float, double, int32, int64); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to.cc b/tensorflow/core/kernels/cwise_op_not_equal_to.cc new file mode 100644 index 0000000000..02d434a1c2 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_not_equal_to.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, double, + int32, int64, complex64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, double, + int64); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc new file mode 100644 index 0000000000..d10dced85f --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_pow.cc @@ -0,0 +1,9 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Pow", functor::pow, float, double, int32, int64, + complex64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Pow", functor::pow, float, double, int64); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_real.cc b/tensorflow/core/kernels/cwise_op_real.cc new file mode 100644 index 0000000000..84295a5a16 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_real.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER_KERNEL_BUILDER(Name("Real").Device(DEVICE_CPU), + UnaryOp>); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("Real").Device(DEVICE_GPU), + UnaryOp>); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc new file mode 100644 index 0000000000..a22b1209de --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, double, complex64); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc new file mode 100644 index 0000000000..baa821690a --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -0,0 +1,17 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER_SELECT(CPU, "Select", "", float); +REGISTER_SELECT(CPU, "Select", "", double); +REGISTER_SELECT(CPU, "Select", "", int32); +REGISTER_SELECT(CPU, "Select", "", int64); +REGISTER_SELECT(CPU, "Select", "", complex64); +REGISTER_SELECT(CPU, "Select", "", string); +#if GOOGLE_CUDA +REGISTER_SELECT(GPU, "Select", "", float); +REGISTER_SELECT(GPU, "Select", "", double); +REGISTER_SELECT(GPU, "Select", "", int32); +REGISTER_SELECT(GPU, "Select", 
"", int64); +REGISTER_SELECT(GPU, "Select", "", complex64); +#endif // GOOGLE_CUDA +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc new file mode 100644 index 0000000000..e03b5d54dd --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, double, complex64); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc new file mode 100644 index 0000000000..59a0bfa1ed --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_sign.cc @@ -0,0 +1,19 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER4(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64); +#if GOOGLE_CUDA +REGISTER3(UnaryOp, GPU, "Sign", functor::sign, float, double, int64); +#endif + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Sign") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .TypeConstraint("T"), + UnaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc new file mode 100644 index 0000000000..e7c87374d7 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_sin.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Sin", functor::sin, float, double, complex64); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Sin", functor::sin, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc new file mode 100644 index 0000000000..f43241264a --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_sqrt.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Sqrt", functor::sqrt, float, double, complex64); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Sqrt", functor::sqrt, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_square.cc b/tensorflow/core/kernels/cwise_op_square.cc new file mode 100644 index 0000000000..510fda49aa --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_square.cc @@ -0,0 +1,9 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(UnaryOp, CPU, "Square", functor::square, float, double, int32, + complex64, int64); +#if GOOGLE_CUDA +REGISTER3(UnaryOp, GPU, "Square", functor::square, float, double, int64); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc new file mode 100644 index 0000000000..c3c5952f8d --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_sub.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Sub", functor::sub, float, double, int32, int64, + complex64); +#if GOOGLE_CUDA +REGISTER3(BinaryOp, GPU, "Sub", functor::sub, float, double, int64); +#endif + +// A special GPU kernel for int32. 
+// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Sub") + .Device(DEVICE_GPU) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc new file mode 100644 index 0000000000..31f4743449 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_tanh.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER3(UnaryOp, CPU, "Tanh", functor::tanh, float, double, complex64); +#if GOOGLE_CUDA +REGISTER2(UnaryOp, GPU, "Tanh", functor::tanh, float, double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h new file mode 100644 index 0000000000..7d818cfbbf --- /dev/null +++ b/tensorflow/core/kernels/cwise_ops.h @@ -0,0 +1,607 @@ +#ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_ +#define TENSORFLOW_KERNELS_CWISE_OPS_H_ + +#include +#include +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +// The following functors (sign, tanh, sigmoid, etc.) are not defined +// by Eigen. When their equivalent are added into the Eigen, we can +// replace them using type aliases. + +namespace Eigen { +namespace internal { + +template +struct scalar_sign_op { + // TODO(zhifengc): this only works for real types. In theory, + // sign(x) = x / |x| works for both real and complex values. + EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { + return T(x > T(0)) - T(x < T(0)); + } +}; + +// TODO(zhifengc): Eigen::internal::pow_impl does not have proper +// EIGEN host/device decoration. We duplicate code here for now. +template +struct pow { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T + operator()(const T& x, const T& y) const { + return std::pow(x, y); + } +}; + +template +struct pow { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x, T y) const { + T res(1); + if (y & 1) res *= x; + y >>= 1; + while (y) { + x *= x; + if (y & 1) res *= x; + y >>= 1; + } + return res; + } +}; + +template +struct scalar_pow2_op : pow::IsInteger> {}; + +template +struct functor_traits > { + enum { + Cost = 5 * NumTraits::MulCost, + PacketAccess = false, + }; +}; + +template +struct scalar_fmod2_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_fmod2_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a, + const T& b) const { + return fmod(a, b); + } +}; + +template +struct scalar_mod2_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_mod2_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& a, const T& b) const { + return a % b; + } +}; + +template +struct functor_traits > { + enum { + Cost = 5, // Roughly the cost of a div + PacketAccess = false, + }; +}; + +// scalar_left and scalar_right are template helpers to partially +// apply a binary function. +// +// Suppose Binary is a binary functor f(x, y), scalar_left<> is a +// unary functor g_x(y) = f(x, y), where x is provided via the +// constructor. Similarly, scalar_right<> is a unary functor g_y(x) = +// f(x, y). 
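// A minimal usage sketch (the template arguments below are typical, not
// quoted from this patch, since the functor type is supplied by the caller):
//
//   float x = 2.0f;
//   Eigen::internal::scalar_left<float, float,
//                                Eigen::internal::scalar_product_op<float> > g(&x);
//   g(3.0f);  // evaluates f(x, y) with x bound, i.e. 2.0f * 3.0f
//
// scalar_right<> binds the second operand instead, giving h(x) = f(x, y)
// with y fixed.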
+ +template ::PacketAccess> +struct scalar_left { + typedef Tout result_type; + const Tin* left; + EIGEN_DEVICE_FUNC inline scalar_left( + const scalar_left& other) // NOLINT(runtime/explicit) + : left(other.left) {} + EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c) : left(c) {} + EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { + return Binary()(*left, right); + } +}; + +template +struct scalar_left { + typedef Tout result_type; + const Tin* left; + EIGEN_DEVICE_FUNC inline scalar_left( + const scalar_left& other) // NOLINT(runtime/explicit) + : left(other.left) {} + EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c) : left(c) {} + EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { + return Binary()(*left, right); + } + + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { + const Packet left_packet = Eigen::internal::pset1(*left); + return Binary().packetOp(left_packet, right_packet); + } +}; + +template +struct functor_traits > { + enum { + Cost = functor_traits::Cost, + PacketAccess = functor_traits::PacketAccess, + }; +}; + +template ::PacketAccess> +struct scalar_right { + typedef Tout result_type; + const Tin* right; + EIGEN_DEVICE_FUNC inline scalar_right( + const scalar_right& other) // NOLINT(runtime/explicit) + : right(other.right) {} + EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c) : right(c) {} + EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { + return Binary()(left, *right); + } +}; + +template +struct scalar_right { + typedef Tout result_type; + const Tin* right; + EIGEN_DEVICE_FUNC inline scalar_right( + const scalar_right& other) // NOLINT(runtime/explicit) + : right(other.right) {} + EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c) : right(c) {} + EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { + return Binary()(left, *right); + } + + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { + const Packet right_packet = Eigen::internal::pset1(*right); + return Binary().packetOp(left_packet, right_packet); + } +}; + +template +struct functor_traits > { + enum { + Cost = functor_traits::Cost, + PacketAccess = functor_traits::PacketAccess, + }; +}; + +// similar to std::equal_to, but with the DEVICE_FUNC qualifier +template +struct equal_to : std::binary_function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + bool operator()(const T& x, const T& y) const { return x == y; } +}; + +// similar to std::not_equal_to, but with the DEVICE_FUNC qualifier +template +struct not_equal_to : std::binary_function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + bool operator()(const T& x, const T& y) const { return x != y; } +}; + +// similar to std::greater, but with the DEVICE_FUNC qualifier +template +struct greater : std::binary_function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + bool operator()(const T& x, const T& y) const { return x > y; } +}; + +// similar to std::less, but with the DEVICE_FUNC qualifier +template +struct less : std::binary_function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + bool operator()(const T& x, const T& y) const { return x < y; } +}; + +// similar to std::greater_equal, but with the DEVICE_FUNC qualifier +template +struct greater_equal : std::binary_function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + bool operator()(const T& x, const T& y) const { return x >= y; } +}; + +// similar to std::less_equal, but with the DEVICE_FUNC qualifier +template +struct less_equal : 
std::binary_function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + bool operator()(const T& x, const T& y) const { return x <= y; } +}; + +} // end namespace internal +} // end namespace Eigen + +namespace tensorflow { +namespace functor { + +//////////////////////////////////////////////////////////////////////////////// +// Helpers +//////////////////////////////////////////////////////////////////////////////// + +// Base template for functors whose input scalar type is T and +// output scalar type is R. +template +struct base { + // func defines operator() and its vectorized version packetOp(). + typedef F func; + + // If true, the functor's corresponding binary op will instantiate + // specialized kernels to perform an optimized broadcast + // operation. Each functor for which this is enabled increases the + // code size, so by default this is disabled for binary functors and + // is enabled on a per-op basis as needed. + static const bool use_bcast_optimization = false; + + // operator() has the signature: + // out_type operator()(in_type in0, in_type in1 ...) + typedef R out_type; + typedef T in_type; + + // TensorFlow provides tensor-ized version of "func". Roughly + // speaking, the tensorflow operation has the signature: + // tout_type op(tin_type in0) + // tout_type op(tin_type in0, tin_type in1) + // tout_type op(tin_type in0, in_type scalar) + typedef typename TTypes::Flat tout_type; + typedef typename TTypes::ConstFlat tin_type; + typedef typename TTypes::ConstScalar tscalar_type; +}; + +// For now, we only apply certain speed optimization for +// float/double's broadcast binary op. +template +struct use_bcast_optimization { + static const bool value = false; +}; + +template <> +struct use_bcast_optimization { + static const bool value = true; +}; + +template <> +struct use_bcast_optimization { + static const bool value = true; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Unary functors +//////////////////////////////////////////////////////////////////////////////// + +// abs(x) = |x| +// neg(x) = - x +// inverse(x) = 1 / x +// square(x) = x^2 +// sqrt(x) = x^(1/2) +// rsqrt(x) = x^(-1/2) +// exp(x) = e^x +// log(x) = natural logrithm of x +// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +// sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic +// +// NOTE: We may eventually implement common functions used in NN +// here. E.g., rectifier, softplus, derivatives of tanh, sigmod, etc. +// For reference, see speech/lstm/eigen_functors.h. + +template +struct abs : base, + typename Eigen::internal::scalar_abs_op::result_type> {}; + +template +struct neg : base > {}; + +template +struct inverse : base > {}; + +template +struct square : base > {}; + +template +struct sqrt : base > {}; + +template +struct rsqrt : base > {}; + +template +struct exp : base > {}; + +template +struct log : base > {}; + +template +struct sign : base > {}; + +template +struct tanh : base > {}; + +template +struct sigmoid : base > {}; + +template +struct sin : base > {}; + +template +struct cos : base > {}; + +struct logical_not : base > {}; + +namespace impl { + +#ifndef __CUDACC__ +// Uses STL std cmath functions. +template +bool isinf(T v) { + return std::isinf(v); +} + +template +bool isnan(T v) { + return std::isnan(v); +} + +template +bool isfinite(T v) { + return std::isfinite(v); +} + +template +T floor(T v) { + return std::floor(v); +} + +template +T ceil(T v) { + return std::ceil(v); +} +#else +// Uses CUDA's functions for float and double. 
+template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isinf(T v) { + return ::isinf(v); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isnan(T v) { + return ::isnan(v); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isfinite(T v) { + return ::isfinite(v); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T floor(T v) { + return ::floor(v); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T ceil(T v) { + return ::ceil(v); +} +#endif +} // end namespace impl + +// NOTE: std::isinf, std::isnan, std::isfinite are plain function. +// Therefore we need to wrap them in functors to be used with Eigen's +// type system. + +template +struct isinf_func { + typedef bool result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const { + return impl::isinf(x); + } +}; + +template +struct isinf : base, bool> {}; + +template +struct isnan_func { + typedef bool result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const { + return impl::isnan(x); + } +}; + +template +struct isnan : base, bool> {}; + +template +struct isfinite_func { + typedef bool result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const { + return impl::isfinite(x); + } +}; + +template +struct isfinite : base, bool> {}; + +template +struct floor_func { + typedef T result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x) const { + return impl::floor(x); + } +}; + +template +struct floor : base > {}; + +template +struct ceil_func { + typedef T result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x) const { + return impl::ceil(x); + } +}; + +template +struct ceil : base > {}; + +//////////////////////////////////////////////////////////////////////////////// +// Binary functors +//////////////////////////////////////////////////////////////////////////////// + +// Binary functors: +// +// add(x, y) = x + y +// sub(x, y) = x - y +// mul(x, y) = x * y +// div(x, y) = x / y +// mod(x, y) = x % y (int32 and int64 only) +// fmod(x, y) = fmod(x, y) (float and double only) +// pow(x, y) = x ^ y +// maximum(x, y) = x > y ? x : y +// minimum(x, y) = x < y ? 
x : y + +template +struct add : base > { + static const bool use_bcast_optimization = true; +}; + +template +struct sub : base > { + static const bool use_bcast_optimization = true; +}; + +template +struct mul : base > {}; + +template +struct div : base > {}; + +template +struct fmod : base > {}; + +template +struct mod : base > {}; + +template +struct pow : base > {}; + +template +struct maximum : base > {}; + +template +struct minimum : base > {}; + +template +struct less : base, bool> {}; + +template +struct less_equal : base, bool> {}; + +template +struct greater : base, bool> {}; + +template +struct greater_equal : base, bool> {}; + +template +struct equal_to : base, bool> {}; + +template +struct not_equal_to : base, bool> {}; + +struct logical_and : base {}; + +struct logical_or : base {}; + +template +struct make_complex_func { + typedef std::complex result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + result_type operator()(T real, T imag) const { + return std::complex(real, imag); + } +}; + +template +struct make_complex : base, std::complex > {}; + +template +struct get_real + : base, typename T::value_type> {}; + +template +struct get_imag + : base, typename T::value_type> {}; + +template +struct conj : base > {}; + +//////////////////////////////////////////////////////////////////////////////// +// Functors takes 1 or 2 tensors, computes the base functor on +// coefficient of the input tensors and puts the results in the output +// tensor. +//////////////////////////////////////////////////////////////////////////////// +template +struct UnaryFunctor { + // Computes on device "d": out[i] = Functor(in[i]) + void operator()(const Device& d, typename Functor::tout_type out, + typename Functor::tin_type in); +}; + +template +struct BinaryFunctor { + // Computes on device "d": out[i] = Functor(in0[i], in1[i]) + void operator()(const Device& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1); + + // Computes on device "d": out[i] = Functor(scalar[0], in[i]) + void Left(const Device& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in); + + // Computes on device "d": out[i] = Functor(in[i], scalar[0]) + void Right(const Device& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar); + + // Computes on device "d": + // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast01)) + // + // TODO(zhifengc): makes BCast a template member function on NDIMS + // instead making BinaryFunctor templates on NDIMS. 
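// Here bcast0/bcast1 give, for each of the NDIMS dimensions, how many times
// in0/in1 must be replicated to reach the common output shape; an entry of 1
// means no broadcasting along that dimension (cf. AllOne() below).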
+ void BCast(const Device& d, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename Eigen::array bcast0, + typename TTypes::ConstTensor in1, + typename Eigen::array bcast1); +}; + +template +bool AllOne(const typename Eigen::array& a) { + for (int i = 0; i < a.size(); ++i) { + if (a[i] != 1) return false; + } + return true; +} + +template +struct SelectFunctor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat cond_flat, + typename TTypes::ConstFlat then_flat, + typename TTypes::ConstFlat else_flat); +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CWISE_OPS_H_ diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc new file mode 100644 index 0000000000..f86d2ddd9a --- /dev/null +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -0,0 +1,42 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { + +BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out, + DataType in) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out})); +} + +void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) { + ctx->SetStatus(errors::Unimplemented( + "Broadcast between ", ctx->input(0).shape().ShortDebugString(), " and ", + ctx->input(1).shape().ShortDebugString(), " is not supported yet.")); +} + +static BCast::Vec FromShape(const TensorShape& shape) { + BCast::Vec ret; + for (int i = 0; i < shape.dims(); ++i) ret.push_back(shape.dim_size(i)); + return ret; +} + +static TensorShape ToShape(const BCast::Vec& vec) { + TensorShape shape; + for (auto elem : vec) shape.AddDim(elem); + return shape; +} + +BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx) + : bcast(FromShape(ctx->input(0).shape()), + FromShape(ctx->input(1).shape())) { + if (!bcast.IsValid()) { + ctx->SetStatus(errors::InvalidArgument( + "Incompatible shapes: ", ctx->input(0).shape().ShortDebugString(), + " vs. ", ctx->input(1).shape().ShortDebugString())); + return; + } + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, ToShape(bcast.output_shape()), &out)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h new file mode 100644 index 0000000000..cf848b86d1 --- /dev/null +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -0,0 +1,390 @@ +#ifndef TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_ +#define TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_ + +// See docs in ../ops/math_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/cwise_ops.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +class BinaryOpShared : public OpKernel { + public: + explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in); + + protected: + struct BinaryOpState { + // Sets up bcast with the shape of in0 and in1, ensures that the bcast + // is valid, and if so, allocates out using ctx->output(...). + // Caller must check ctx->status() upon return for non-ok status. + // If ctx->status().ok() is true, then out is guaranteed to be allocated. 
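// Typical usage, mirroring BinaryOp::Compute() below: construct a
// BinaryOpState, bail out if !ctx->status().ok(), then read state.out
// and state.bcast.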
+ BinaryOpState(OpKernelContext* ctx); + + BCast bcast; + Tensor* out = nullptr; + }; + + template + static Eigen::array ToIndexArray( + const BCast::Vec& vec) { + CHECK_EQ(vec.size(), NDIMS); + Eigen::array ret; + for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i]; + return ret; + } + void SetUnimplementedError(OpKernelContext* ctx); +}; + +// Coefficient-wise binary operations: +// Device: E.g., CPUDevice, GPUDevice. +// Functor: defined in cwise_functors.h. E.g., functor::add2. +template +class BinaryOp : public BinaryOpShared { + public: + typedef typename Functor::in_type Tin; // Input scalar data type. + typedef typename Functor::out_type Tout; // Output scalar data type. + + explicit BinaryOp(OpKernelConstruction* ctx) + : BinaryOpShared(ctx, DataTypeToEnum::v(), + DataTypeToEnum::v()) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& in0 = ctx->input(0); + const Tensor& in1 = ctx->input(1); + // 'state': Shared helper not dependent on T to reduce code size + BinaryOpState state(ctx); + if (!ctx->status().ok()) return; + Tensor* out = state.out; + BCast* bcast = &state.bcast; + if (out->NumElements() == 0) { + return; + } + const int ndims = bcast->x_reshape().size(); + if (ndims <= 1) { + if (in1.NumElements() == 1) { + // tensor op scalar + functor::BinaryFunctor().Right( + ctx->eigen_device(), out->flat(), in0.flat(), + in1.scalar()); + return; + } + if (in0.NumElements() == 1) { + // scalar op tensor + functor::BinaryFunctor().Left( + ctx->eigen_device(), out->flat(), in0.scalar(), + in1.flat()); + return; + } + functor::BinaryFunctor()( + ctx->eigen_device(), out->flat(), in0.flat(), + in1.flat()); + return; + } + + if (ndims == 2) { + functor::BinaryFunctor().BCast( + ctx->eigen_device(), + out->shaped(bcast->result_shape()), + in0.shaped(bcast->x_reshape()), + ToIndexArray<2>(bcast->x_bcast()), + in1.shaped(bcast->y_reshape()), + ToIndexArray<2>(bcast->y_bcast())); + return; + } + + if (ndims == 3) { + functor::BinaryFunctor().BCast( + ctx->eigen_device(), + out->shaped(bcast->result_shape()), + in0.shaped(bcast->x_reshape()), + ToIndexArray<3>(bcast->x_bcast()), + in1.shaped(bcast->y_reshape()), + ToIndexArray<3>(bcast->y_bcast())); + return; + } + + SetUnimplementedError(ctx); + } + + private: +}; + +// Coefficient-wise unary operations: +// Device: E.g., CPUDevice, GPUDevice. +// Functor: defined in cwise_functors.h. E.g., functor::sqrt. +template +class UnaryOp : public OpKernel { + public: + typedef typename Functor::in_type Tin; // Input scalar data type. + typedef typename Functor::out_type Tout; // Output scalar data type. + // Tin may be different from Tout. E.g., abs: complex64 -> float + + explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + auto in = DataTypeToEnum::v(); + auto out = DataTypeToEnum::v(); + OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out})); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& inp = ctx->input(0); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); + functor::UnaryFunctor()( + ctx->eigen_device(), out->flat(), inp.flat()); + } +}; + +// Coefficient-wise select operation. +// Device: E.g., CPUDevice, GPUDevice. 
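// Computes, element-wise, out[i] = cond[i] ? then[i] : else[i]; all three
// inputs must have the same shape (enforced via ValidateInputsAreSameShape
// in Compute() below).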
+template +class SelectOp : public OpKernel { + public: + explicit SelectOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + auto dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_BOOL, dt, dt}, {dt})); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& in0 = ctx->input(0); + const Tensor& in1 = ctx->input(1); + const Tensor& in2 = ctx->input(2); + if (!ctx->ValidateInputsAreSameShape(this)) return; + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out)); + functor::SelectFunctor func; + func(ctx->eigen_device(), out->flat(), in0.flat(), + in1.flat(), in2.flat()); + } +}; + +namespace functor { + +// For CPUDevice, we do operations inline if the resulting tensor is +// modestly sized. +static bool DoInline(size_t size) { return size <= 32768; } + +template +void Assign(const D& d, OUT out, RHS rhs) { + if (DoInline(out.size())) { + out = rhs; + } else { + out.device(d) = rhs; + } +} + +// Partial specialization of BinaryFunctor. +template +struct BinaryFunctor { + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1) { + Assign(d, out, in0.binaryExpr(in1, typename Functor::func())); + } + + void Left(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_left Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + + void Right(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_right Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + +#if !defined(EIGEN_HAS_INDEX_LIST) + inline Eigen::DSizes NByOne(int n) { + return Eigen::DSizes(n, 1); + } + inline Eigen::DSizes OneByM(int m) { + return Eigen::DSizes(1, m); + } +#else + inline Eigen::IndexList> NByOne(int n) { + Eigen::IndexList> ret; + ret.set(0, n); + return ret; + } + inline Eigen::IndexList, int> OneByM(int m) { + Eigen::IndexList, int> ret; + ret.set(1, m); + return ret; + } +#endif + + void BCast(const CPUDevice& dev, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename Eigen::array bcast0, + typename TTypes::ConstTensor in1, + typename Eigen::array bcast1) { + typedef typename Functor::in_type T; + typename Functor::func func; + if ((NDIMS == 2) && Functor::use_bcast_optimization && + use_bcast_optimization::value) { + // Optimize for speed by using Eigen::type2index and avoid + // .broadcast() when we know its a no-op. + // + // Here, we need to handle 6 cases depending on how many "1" + // exist in in0 and in1's shapes (4 numbers in total). It's not + // possible that two shapes have more than 2 1s because those + // are simplified to NDIMS==1 case. + // + // Because this optimization increases the binary size for each + // Functor (+, -, *, /, <, <=, etc.), type and ndim combination. + // we only apply such optimization for selected ops/types/ndims. 
+ // + // Because NDIMS, Functor::use_broadcast_optimization and + // use_broadcast_optimization are compile-time constant, gcc + // does a decent job avoiding generating code when conditions + // are not met. + const int a = in0.dimension(0); // in0 is shape [a, b] + const int b = in0.dimension(1); + const int c = in1.dimension(0); // in1 is shape [c, d] + const int d = in1.dimension(1); + if ((a == 1) && (d == 1)) { + auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c)); + auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b)); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if ((b == 1) && (c == 1)) { + auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d)); + auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a)); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if (a == 1) { + auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c)); + auto rhs = in1; + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if (b == 1) { + auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d)); + auto rhs = in1; + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if (c == 1) { + auto lhs = in0; + auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a)); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if (d == 1) { + auto lhs = in0; + auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b)); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + + const bool bcast0_all_one = AllOne(bcast0); + const bool bcast1_all_one = AllOne(bcast1); + if (bcast0_all_one && !bcast1_all_one) { + auto lhs = in0; // No need to do broadcast for in0 + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + + if (!bcast0_all_one && bcast1_all_one) { + auto lhs = in0.broadcast(bcast0); + auto rhs = in1; // No need to do broadcast for in1 + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + } + + // Fallback path. Always work and probably slower. + auto lhs = in0.broadcast(bcast0); + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + } +}; + +// Partial specialization of UnaryFunctor. +template +struct UnaryFunctor { + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in) { + Assign(d, out, in.unaryExpr(typename Functor::func())); + } +}; + +template +struct SelectFunctor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat cond_flat, + typename TTypes::ConstFlat then_flat, + typename TTypes::ConstFlat else_flat) { + Assign(d, out, cond_flat.select(then_flat, else_flat)); + } +}; + +} // end namespace functor + +#define REGISTER_SELECT(D, N, F, T) \ + REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint("T"), \ + SelectOp) + +#define REGISTER(OP, D, N, F, T) \ + REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint("T"), \ + OP>); + +// Macros to register kernels for multiple types (T0, T1, etc.) on +// device type "D" (CPU or GPU) for operatin "N" (e.g., sqrt) using +// the functor "F" (e.g., functor:sqrt). 
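// For example, REGISTER3(UnaryOp, CPU, "Sqrt", functor::sqrt, float, double,
// complex64) from cwise_op_sqrt.cc expands (on non-Android builds) into three
// registrations of roughly this form -- the template arguments are written
// out here as assumed, since REGISTER fills them in from its parameters:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Sqrt").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       UnaryOp<CPUDevice, functor::sqrt<float> >);
//
// and likewise for double and complex64.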
+ +#ifdef __ANDROID__ +// On Android, only register the first type (float) +#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0) +#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0) +#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0) +#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0) +#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0) +#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \ + REGISTER(OP, D, N, F, T0) +#else // !__ANDROID__ +#define REGISTER2(OP, D, N, F, T0, T1) \ + REGISTER(OP, D, N, F, T0) \ + REGISTER(OP, D, N, F, T1) +#define REGISTER3(OP, D, N, F, T0, T1, T2) \ + REGISTER2(OP, D, N, F, T0, T1) \ + REGISTER(OP, D, N, F, T2) +#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ + REGISTER2(OP, D, N, F, T0, T1) \ + REGISTER2(OP, D, N, F, T2, T3) +#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \ + REGISTER3(OP, D, N, F, T0, T1, T2) \ + REGISTER2(OP, D, N, F, T3, T4) +#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \ + REGISTER3(OP, D, N, F, T0, T1, T2) \ + REGISTER3(OP, D, N, F, T3, T4, T5) +#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \ + REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ + REGISTER3(OP, D, N, F, T4, T5, T6) +#endif // __ANDROID__ + +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_ diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h new file mode 100644 index 0000000000..b0dc027144 --- /dev/null +++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h @@ -0,0 +1,135 @@ +#if !GOOGLE_CUDA +#error This file must only be included when building with Cuda support +#endif + +#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_ +#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_ + +#define EIGEN_USE_GPU + +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/kernels/cwise_ops.h" +#include "tensorflow/core/framework/tensor_types.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; +typedef std::complex complex64; + +// Partial specialization of UnaryFunctor. +template +struct UnaryFunctor { + void operator()(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in) { + out.device(d) = in.unaryExpr(typename Functor::func()); + } +}; + +// Partial specialization of BinaryFunctor. 
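// Unlike the CPU specialization in cwise_ops_common.h, every assignment here
// goes through out.device(d), so the whole expression is evaluated by Eigen's
// GPU device rather than inline on the host.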
+template +struct BinaryFunctor { + void operator()(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1) { + out.device(d) = in0.binaryExpr(in1, typename Functor::func()); + } + + void Left(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_left Unary; + out.device(d) = in.unaryExpr(Unary(scalar.data())); + } + + void Right(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_right Unary; + out.device(d) = in.unaryExpr(Unary(scalar.data())); + } + + void BCast(const GPUDevice& d, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename Eigen::array bcast0, + typename TTypes::ConstTensor in1, + typename Eigen::array bcast1) { + typedef typename Functor::in_type T; + typename Functor::func func; + if ((NDIMS == 2) && Functor::use_bcast_optimization && + use_bcast_optimization::value) { + const bool bcast0_all_one = AllOne(bcast0); + const bool bcast1_all_one = AllOne(bcast1); + if (bcast0_all_one && !bcast1_all_one) { + out.device(d) = in0.binaryExpr(in1.broadcast(bcast1), func); + return; + } + if (!bcast0_all_one && bcast1_all_one) { + out.device(d) = in0.broadcast(bcast0).binaryExpr(in1, func); + return; + } + } + out.device(d) = + in0.broadcast(bcast0).binaryExpr(in1.broadcast(bcast1), func); + } +}; + +template +struct SelectFunctor { + void operator()(const GPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat cond_flat, + typename TTypes::ConstFlat then_flat, + typename TTypes::ConstFlat else_flat) { + out.device(d) = cond_flat.select(then_flat, else_flat); + } +}; + +// Macros to explicitly instantiate kernels on GPU for multiple types +// (T0, T1, etc.) for UnaryFunctor (e.g., functor:sqrt). +#define DEFINE_UNARY1(F, T) template struct UnaryFunctor > +#define DEFINE_UNARY2(F, T0, T1) \ + DEFINE_UNARY1(F, T0); \ + DEFINE_UNARY1(F, T1) +#define DEFINE_UNARY3(F, T0, T1, T2) \ + DEFINE_UNARY2(F, T0, T1); \ + DEFINE_UNARY1(F, T2) +#define DEFINE_UNARY4(F, T0, T1, T2, T3) \ + DEFINE_UNARY2(F, T0, T1); \ + DEFINE_UNARY2(F, T2, T3) +#define DEFINE_UNARY5(F, T0, T1, T2, T3, T4) \ + DEFINE_UNARY2(F, T0, T1); \ + DEFINE_UNARY3(F, T2, T3, T4) + +// Macros to explicitly instantiate kernels on GPU for multiple types +// (T0, T1, etc.) for BinaryFunctor. 
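// For example, the .cu.cc file backing cwise_op_sub.cc's GPU registrations
// (not shown in this excerpt) would be expected to contain something like
//
//   DEFINE_BINARY3(sub, float, double, int64);
//
// which explicitly instantiates BinaryFunctor for functor::sub over those
// three types and NDIMS = 1, 2, 3.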
+#define DEFINE_BINARY1(F, T) \ + template struct BinaryFunctor, 1>; \ + template struct BinaryFunctor, 2>; \ + template struct BinaryFunctor, 3> +#define DEFINE_BINARY2(F, T0, T1) \ + DEFINE_BINARY1(F, T0); \ + DEFINE_BINARY1(F, T1) +#define DEFINE_BINARY3(F, T0, T1, T2) \ + DEFINE_BINARY2(F, T0, T1); \ + DEFINE_BINARY1(F, T2) +#define DEFINE_BINARY4(F, T0, T1, T2, T3) \ + DEFINE_BINARY2(F, T0, T1); \ + DEFINE_BINARY2(F, T2, T3) +#define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \ + DEFINE_BINARY2(F, T0, T1); \ + DEFINE_BINARY3(F, T2, T3, T4) + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_ diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc new file mode 100644 index 0000000000..56af248117 --- /dev/null +++ b/tensorflow/core/kernels/cwise_ops_test.cc @@ -0,0 +1,167 @@ +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { + +// Creates a Graph which applies a unary "func" on a 3D float tensor +// of "num" elements. +static Graph* Unary(const string& func, int num) { + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + Tensor data(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)})); + CHECK_GT(data.NumElements(), 0); + data.flat().setRandom(); + test::graph::Unary(g, func, test::graph::Constant(g, data), 0); + return g; +} + +static int kRows = 100000; + +static int RowsAndColsArg(int r, int c) { return r * kRows + c; } +static int RowsFromArg(int arg) { return (arg / kRows); } +static int ColsFromArg(int arg) { return (arg % kRows); } + +#define BM_UNARY(DEVICE, FUNC) \ + static void BM_##DEVICE##_##FUNC(int iters, int num) { \ + const int64 tot = static_cast(iters) * num; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(float)); \ + test::Benchmark(#DEVICE, Unary(#FUNC, num)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_##FUNC)->Range(4 << 10, 1 << 20); + +BM_UNARY(cpu, Floor); +BM_UNARY(gpu, Floor); + +// data func scalar. 
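// I.e., builds a graph applying a binary "func" between a large random tensor
// and a scalar constant, e.g. Less(lhs, c) or Add(lhs, c) as benchmarked below.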
+static Graph* BinaryScalar(int num, const string& func) { + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)})); + lhs.flat().setRandom(); + Tensor rhs(DT_FLOAT, TensorShape({})); + rhs.flat().setRandom(); + test::graph::Binary(g, func, test::graph::Constant(g, lhs), + test::graph::Constant(g, rhs)); + return g; +} + +#define BM_BINARY_SCALAR(DEVICE, FUNC) \ + static void BM_##DEVICE##_##FUNC##_scalar(int iters, int num) { \ + const int64 tot = static_cast(iters) * num; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(float)); \ + test::Benchmark(#DEVICE, BinaryScalar(num, #FUNC)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_##FUNC##_scalar) \ + ->Arg(4096) /* must >= 4096 */ \ + ->Arg(32768) \ + ->Arg(131072) \ + ->Arg(1048576); + +BM_BINARY_SCALAR(cpu, Less); +BM_BINARY_SCALAR(gpu, Less); +BM_BINARY_SCALAR(cpu, Add); +BM_BINARY_SCALAR(gpu, Add); +#undef BM_BINARY_SCALAR + +static Graph* BiasAdd(int rows, int cols) { + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + Tensor lhs(DT_FLOAT, TensorShape({rows, cols})); + lhs.flat().setRandom(); + TensorShape rhs_shape; + rhs_shape = TensorShape({cols}); + Tensor rhs(DT_FLOAT, rhs_shape); + rhs.flat().setRandom(); + test::graph::Binary(g, "BiasAdd", test::graph::Constant(g, lhs), + test::graph::Constant(g, rhs)); + return g; +} + +#define BM_BIAS_ADD(DEVICE, R, C) \ + static void BM_##DEVICE##_BiasAdd_R##R##_C##C(int iters, int arg) { \ + const int rows = RowsFromArg(arg); \ + const int cols = ColsFromArg(arg); \ + const int64 tot = static_cast(iters) * rows * cols; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(float)); \ + test::Benchmark(#DEVICE, BiasAdd(rows, cols)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_BiasAdd_R##R##_C##C)->Arg(RowsAndColsArg(R, C)); + +#define BM_BIAS_ADD_ALL(DEVICE) \ + BM_BIAS_ADD(DEVICE, 512, 2048); \ + BM_BIAS_ADD(DEVICE, 512, 4096); \ + BM_BIAS_ADD(DEVICE, 2048, 512); \ + BM_BIAS_ADD(DEVICE, 4096, 512); + +BM_BIAS_ADD_ALL(cpu); +BM_BIAS_ADD_ALL(gpu); +#undef BM_BIAS_ADD_ALL +#undef BM_BIAS_ADD + +static Graph* BcastAdd(int rows, int cols, int dim) { + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + Tensor lhs(DT_FLOAT, TensorShape({rows, cols})); + lhs.flat().setRandom(); + TensorShape rhs_shape; + if (dim == 0) { + rhs_shape = TensorShape({rows, 1}); + } else { + rhs_shape = TensorShape({cols}); + } + Tensor rhs(DT_FLOAT, rhs_shape); + rhs.flat().setRandom(); + test::graph::Binary(g, "Add", test::graph::Constant(g, lhs), + test::graph::Constant(g, rhs)); + return g; +} + +#define BM_BCAST_ADD_ROW(DEVICE, R, C) \ + static void BM_##DEVICE##_BcastAddRow_R##R##_C##C(int iters, int arg) { \ + const int rows = RowsFromArg(arg); \ + const int cols = ColsFromArg(arg); \ + const int64 tot = static_cast(iters) * rows * cols; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(float)); \ + test::Benchmark(#DEVICE, BcastAdd(rows, cols, 0)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_BcastAddRow_R##R##_C##C)->Arg(RowsAndColsArg(R, C)); + +#define BM_BCAST_ADD_ROW_ALL(DEVICE) \ + BM_BCAST_ADD_ROW(DEVICE, 512, 2048); \ + BM_BCAST_ADD_ROW(DEVICE, 512, 4096); \ + BM_BCAST_ADD_ROW(DEVICE, 2048, 512); \ + BM_BCAST_ADD_ROW(DEVICE, 4096, 512); +BM_BCAST_ADD_ROW_ALL(cpu); +BM_BCAST_ADD_ROW_ALL(gpu); +#undef BM_BCAST_ADD_ROW_ALL +#undef BM_BCAST_ADD_ROW + +#define BM_BCAST_ADD_COL(DEVICE, R, C) \ + static void 
BM_##DEVICE##_BcastAddCol_R##R##_C##C(int iters, int arg) { \ + const int rows = RowsFromArg(arg); \ + const int cols = ColsFromArg(arg); \ + const int64 tot = static_cast(iters) * rows * cols; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(float)); \ + test::Benchmark(#DEVICE, BcastAdd(rows, cols, 1)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_BcastAddCol_R##R##_C##C)->Arg(RowsAndColsArg(R, C)); + +#define BM_BCAST_ADD_COL_ALL(DEVICE) \ + BM_BCAST_ADD_COL(DEVICE, 512, 2048); \ + BM_BCAST_ADD_COL(DEVICE, 512, 4096); \ + BM_BCAST_ADD_COL(DEVICE, 2048, 512); \ + BM_BCAST_ADD_COL(DEVICE, 4096, 512); +BM_BCAST_ADD_COL_ALL(cpu); +BM_BCAST_ADD_COL_ALL(gpu); +#undef BM_BCAST_ADD_COL_ALL +#undef BM_BCAST_ADD_COL + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc new file mode 100644 index 0000000000..0919bab96f --- /dev/null +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -0,0 +1,222 @@ +// See docs in ../ops/parsing_ops.cc. +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class DecodeCSVOp : public OpKernel { + public: + explicit DecodeCSVOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string delim; + + OP_REQUIRES_OK(ctx, ctx->GetAttr("OUT_TYPE", &out_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("field_delim", &delim)); + + OP_REQUIRES(ctx, delim.size() == 1, + errors::InvalidArgument("field_delim should be only 1 char")); + + delim_ = delim[0]; + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* records; + OpInputList record_defaults; + + OP_REQUIRES_OK(ctx, ctx->input("records", &records)); + OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults)); + + for (int i = 0; i < record_defaults.size(); ++i) { + OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults[i].NumElements())); + } + + auto records_t = records->flat(); + int records_size = records_t.size(); + + OpOutputList output; + OP_REQUIRES_OK(ctx, ctx->output_list("output", &output)); + + for (size_t i = 0; i < out_type_.size(); ++i) { + Tensor* out = nullptr; + output.allocate(i, records->shape(), &out); + } + + for (int i = 0; i < records_size; ++i) { + const StringPiece record(records_t(i)); + std::vector fields; + ExtractFields(ctx, record, &fields); + OP_REQUIRES(ctx, fields.size() == out_type_.size(), + errors::InvalidArgument("Expect ", out_type_.size(), + " fields but have ", fields.size(), + " in record ", i)); + + // Check each field in the record + for (size_t f = 0; f < out_type_.size(); ++f) { + const DataType& dtype = out_type_[f]; + switch (dtype) { + case DT_INT32: { + // If this field is empty, check if default is given: + // If yes, use default value; Otherwise report error. 
+ if (fields[f].empty()) { + OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, + errors::InvalidArgument( + "Field ", f, + " is required but missing in record ", i, "!")); + + output[f]->flat()(i) = record_defaults[f].flat()(0); + } else { + int32 value; + OP_REQUIRES(ctx, strings::safe_strto32(fields[f].c_str(), &value), + errors::InvalidArgument("Field ", f, " in record ", i, + " is not a valid int32: ", + fields[f])); + output[f]->flat()(i) = value; + } + break; + } + case DT_INT64: { + // If this field is empty, check if default is given: + // If yes, use default value; Otherwise report error. + if (fields[f].empty()) { + OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, + errors::InvalidArgument( + "Field ", f, + " is required but missing in record ", i, "!")); + + output[f]->flat()(i) = record_defaults[f].flat()(0); + } else { + int64 value; + OP_REQUIRES(ctx, strings::safe_strto64(fields[f].c_str(), &value), + errors::InvalidArgument("Field ", f, " in record ", i, + " is not a valid int64: ", + fields[f])); + output[f]->flat()(i) = value; + } + break; + } + case DT_FLOAT: { + // If this field is empty, check if default is given: + // If yes, use default value; Otherwise report error. + if (fields[f].empty()) { + OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, + errors::InvalidArgument( + "Field ", f, + " is required but missing in record ", i, "!")); + output[f]->flat()(i) = record_defaults[f].flat()(0); + } else { + float value; + OP_REQUIRES(ctx, strings::safe_strtof(fields[f].c_str(), &value), + errors::InvalidArgument("Field ", f, " in record ", i, + " is not a valid float: ", + fields[f])); + output[f]->flat()(i) = value; + } + break; + } + case DT_STRING: { + // If this field is empty, check if default is given: + // If yes, use default value; Otherwise report error. 
+ if (fields[f].empty()) { + OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, + errors::InvalidArgument( + "Field ", f, + " is required but missing in record ", i, "!")); + output[f]->flat()(i) = + record_defaults[f].flat()(0); + } else { + output[f]->flat()(i) = fields[f]; + } + break; + } + default: + OP_REQUIRES(ctx, false, + errors::InvalidArgument("csv: data type ", dtype, + " not supported in field ", f)); + } + } + } + } + + private: + std::vector out_type_; + char delim_; + + void ExtractFields(OpKernelContext* ctx, StringPiece input, + std::vector* result) { + int current_idx = 0; + if (!input.empty()) { + while (static_cast(current_idx) < input.size()) { + if (input[current_idx] == '\n' || input[current_idx] == '\r') { + current_idx++; + continue; + } + + bool quoted = false; + if (input[current_idx] == '"') { + quoted = true; + current_idx++; + } + + // This is the body of the field; + string field; + if (!quoted) { + while (static_cast(current_idx) < input.size() && + input[current_idx] != delim_) { + OP_REQUIRES(ctx, input[current_idx] != '"' && + input[current_idx] != '\n' && + input[current_idx] != '\r', + errors::InvalidArgument( + "Unquoted fields cannot have quotes/CRLFs inside")); + field += input[current_idx]; + current_idx++; + } + + // Go to next field or the end + current_idx++; + } else { + // Quoted field needs to be ended with '"' and delim or end + while ( + (static_cast(current_idx) < input.size() - 1) && + (input[current_idx] != '"' || input[current_idx + 1] != delim_)) { + if (input[current_idx] != '"') { + field += input[current_idx]; + current_idx++; + } else { + OP_REQUIRES( + ctx, input[current_idx + 1] == '"', + errors::InvalidArgument("Quote inside a string has to be " + "escaped by another quote")); + field += '"'; + current_idx += 2; + } + } + + OP_REQUIRES( + ctx, + input[current_idx] == '"' && + (static_cast(current_idx) == input.size() - 1 || + input[current_idx + 1] == delim_), + errors::InvalidArgument("Quoted field has to end with quote " + "followed by delim or end")); + + current_idx += 2; + } + + result->push_back(field); + } + + // Check if the last field is missing + if (input[input.size() - 1] == delim_) result->push_back(string()); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("DecodeCSV").Device(DEVICE_CPU), DecodeCSVOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/decode_jpeg_op.cc b/tensorflow/core/kernels/decode_jpeg_op.cc new file mode 100644 index 0000000000..e41d3f3e11 --- /dev/null +++ b/tensorflow/core/kernels/decode_jpeg_op.cc @@ -0,0 +1,72 @@ +// See docs in ../ops/image_ops.cc + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/jpeg/jpeg_mem.h" + +namespace tensorflow { + +// Decode the contents of a JPEG file +class DecodeJpegOp : public OpKernel { + public: + explicit DecodeJpegOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("channels", &flags_.components)); + OP_REQUIRES(context, flags_.components == 0 || flags_.components == 1 || + flags_.components == 3, + errors::InvalidArgument("channels must be 0, 1, or 3, got ", + flags_.components)); + OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio)); + 
OP_REQUIRES(context, flags_.ratio == 1 || flags_.ratio == 2 || + flags_.ratio == 4 || flags_.ratio == 8, + errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ", + flags_.ratio)); + OP_REQUIRES_OK( + context, context->GetAttr("fancy_upscaling", &flags_.fancy_upscaling)); + OP_REQUIRES_OK(context, + context->GetAttr("try_recover_truncated", + &flags_.try_recover_truncated_jpeg)); + OP_REQUIRES_OK(context, context->GetAttr("acceptable_fraction", + &flags_.min_acceptable_fraction)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& contents = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), + errors::InvalidArgument("contents must be scalar, got shape ", + contents.shape().ShortDebugString())); + const StringPiece input = contents.scalar()(); + OP_REQUIRES(context, input.size() <= std::numeric_limits::max(), + errors::InvalidArgument("JPEG contents are too large for int: ", + input.size())); + + // Decode image, allocating tensor once the image size is known + Tensor* output = NULL; + OP_REQUIRES( + context, + jpeg::Uncompress( + input.data(), input.size(), flags_, NULL, + [=, &output](int width, int height, int channels) -> uint8* { + Status status(context->allocate_output( + 0, TensorShape({height, width, channels}), &output)); + if (!status.ok()) { + VLOG(1) << status; + context->SetStatus(status); + return nullptr; + } + return output->flat().data(); + }), + errors::InvalidArgument("Invalid JPEG data, size ", input.size())); + } + + private: + jpeg::UncompressFlags flags_; +}; +REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeJpegOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/decode_png_op.cc b/tensorflow/core/kernels/decode_png_op.cc new file mode 100644 index 0000000000..e8071526f9 --- /dev/null +++ b/tensorflow/core/kernels/decode_png_op.cc @@ -0,0 +1,69 @@ +// See docs in ../ops/image_ops.cc + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/png/png_io.h" + +namespace tensorflow { + +// Decode the contents of a PNG file +class DecodePngOp : public OpKernel { + public: + explicit DecodePngOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_)); + OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3 || + channels_ == 4, + errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ", + channels_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& contents = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), + errors::InvalidArgument("contents must be scalar, got shape ", + contents.shape().ShortDebugString())); + + // Start decoding image to get shape details + const StringPiece data = contents.scalar()(); + png::DecodeContext decode; + OP_REQUIRES( + context, png::CommonInitDecode(data, channels_, 8, &decode), + errors::InvalidArgument("Invalid PNG header, data size ", data.size())); + + // Verify that width and height don't overflow int + const int width = decode.width; + const int height = decode.height; + if (width != static_cast(decode.width) || + height != static_cast(decode.height)) { + 
png::CommonFreeDecode(&decode); + OP_REQUIRES(context, false, + errors::InvalidArgument("PNG size too large for int: ", + decode.width, " by ", decode.height)); + } + + // Allocate tensor + Tensor* output = nullptr; + const auto status = context->allocate_output( + 0, TensorShape({height, width, decode.channels}), &output); + if (!status.ok()) png::CommonFreeDecode(&decode); + OP_REQUIRES_OK(context, status); + + // Finish decoding image + OP_REQUIRES( + context, png::CommonFinishDecode(output->flat().data(), + decode.channels * width, &decode), + errors::InvalidArgument("Invalid PNG data, size ", data.size())); + } + + private: + int channels_; +}; +REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodePngOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc new file mode 100644 index 0000000000..ef24c333a4 --- /dev/null +++ b/tensorflow/core/kernels/decode_raw_op.cc @@ -0,0 +1,90 @@ +// See docs in ../ops/parse_ops.cc. + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +template +class DecodeRawOp : public OpKernel { + public: + explicit DecodeRawOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("little_endian", &little_endian_)); + OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_type_)); + } + + void Compute(OpKernelContext* context) override { + const auto& input = context->input(0); + int str_size = -1; + auto flat_in = input.flat(); + for (int i = 0; i < flat_in.size(); ++i) { + const string& in_str = flat_in(i); + if (str_size == -1) { + str_size = in_str.size(); + } else { + OP_REQUIRES(context, str_size == in_str.size(), + errors::InvalidArgument( + "DecodeRaw requires input strings to all be the same " + "size, but element ", + i, " has size ", str_size, " != ", in_str.size())); + } + } + TensorShape out_shape = input.shape(); + if (str_size == -1) { // Empty input + out_shape.AddDim(1); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("output", out_shape, + &output_tensor)); + return; + } + OP_REQUIRES( + context, str_size % sizeof(T) == 0, + errors::InvalidArgument("Input to DecodeRaw has length ", str_size, + " that is not a multiple of ", sizeof(T), + ", the size of ", DataTypeString(out_type_))); + const int added_dim = str_size / sizeof(T); + out_shape.AddDim(added_dim); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output("output", out_shape, &output_tensor)); + auto out = output_tensor->flat_inner_dims(); + DCHECK_EQ(flat_in.size(), out.dimensions()[0]); + OP_REQUIRES( + context, + little_endian_ == ::tensorflow::port::kLittleEndian || sizeof(T) == 1, + errors::Unimplemented("Unimplemented support for little_endian=", + little_endian_ ? "true" : "false")); + // Endianness matches, so just copy each string byte-for-byte. 
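// (If little_endian_ did not match the host byte order, the OP_REQUIRES above
// only admits sizeof(T) == 1 types, so a plain byte copy remains correct.)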
+ T* out_data = out.data(); + for (int i = 0; i < flat_in.size(); ++i) { + const T* in_data = reinterpret_cast(flat_in(i).data()); + memcpy(out_data, in_data, str_size); + out_data += added_dim; + } + } + + private: + bool little_endian_; + DataType out_type_; +}; + +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("DecodeRaw").Device(DEVICE_CPU).TypeConstraint("out_type"), \ + DecodeRawOp) + +REGISTER(float); +REGISTER(double); +REGISTER(int32); +REGISTER(uint8); +REGISTER(int16); +REGISTER(int8); +REGISTER(int64); + +#undef REGISTER + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc new file mode 100644 index 0000000000..f56c37b4ef --- /dev/null +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -0,0 +1,136 @@ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/assign_op.h" +#include "tensorflow/core/kernels/dense_update_ops.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +template +class AssignOpT : public AssignOp { + public: + using AssignOp::AssignOp; + + void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override { + functor::DenseUpdate copy; + copy(context->eigen_device(), lhs->flat(), rhs.flat()); + } +}; + +// TODO(jeff): Get rid of use_exclusive_lock_ option +template +class DenseUpdateOp : public OpKernel { + public: + explicit DenseUpdateOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("use_locking", &use_exclusive_lock_)); + const DataType dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(context, context->MatchSignature({MakeRefType(dt), dt}, + {MakeRefType(dt)})); + } + + void Compute(OpKernelContext* context) override { + // We always return the input ref. + context->forward_ref_input_to_ref_output(0, 0); + + if (use_exclusive_lock_) { + mutex_lock l(*context->input_ref_mutex(0)); + DoUpdate(context); + } else { + DoUpdate(context); + } + } + + private: + void DoUpdate(OpKernelContext* context) { + Tensor Tparams = context->mutable_input(0, use_exclusive_lock_); + const Tensor& Tupdate = context->input(1); + OP_REQUIRES(context, Tparams.IsInitialized(), + errors::FailedPrecondition("Attempting to use uninitialized " + "parameters: ", + def().input(0))); + OP_REQUIRES( + context, Tparams.IsSameSize(Tupdate), + errors::InvalidArgument("Parameters and update must be the same size")); + + functor::DenseUpdate update_functor; + update_functor(context->eigen_device(), Tparams.flat(), + Tupdate.flat()); + } + + bool use_exclusive_lock_; +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Assign").Device(DEVICE_CPU).TypeConstraint("T"), \ + AssignOpT); + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA +// Only register 'Assign' on GPU for the subset of types also supported by +// 'Variable' (see variable_ops.cc.) 
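// (Aside: a standalone sketch of what the functor::DenseUpdate specializations
// used above and declared in dense_update_ops.h reduce to -- a device-dispatched
// Eigen tensor assignment. Eigen::DefaultDevice stands in for the CPU/GPU
// devices here, and the include path may differ per build; illustrative only.)
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

int main() {
  Eigen::Tensor<float, 1> params(4), update(4);
  params.setZero();
  update.setConstant(2.0f);
  Eigen::DefaultDevice d;
  params.device(d) = update;   // ASSIGN  (Assign kernel)
  params.device(d) += update;  // ADD     (AssignAdd kernel)
  params.device(d) -= update;  // SUB     (AssignSub kernel)
  return 0;
}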
+#define REGISTER_GPU_KERNELS(type) \ + namespace functor { \ + template <> \ + void DenseUpdate::operator()( \ + const GPUDevice& d, typename TTypes::Flat lhs, \ + typename TTypes::ConstFlat rhs); \ + extern template struct DenseUpdate; \ + } \ + REGISTER_KERNEL_BUILDER( \ + Name("Assign").Device(DEVICE_GPU).TypeConstraint("T"), \ + AssignOpT); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS +#endif // GOOGLE_CUDA + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint("T"), \ + DenseUpdateOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignSub").Device(DEVICE_CPU).TypeConstraint("T"), \ + DenseUpdateOp); + +TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC_FOR_OP(T, OP) \ + template <> \ + void DenseUpdate::operator()( \ + const GPUDevice& d, typename TTypes::Flat params, \ + typename TTypes::ConstFlat update); \ + extern template struct DenseUpdate +#define DECLARE_GPU_SPEC(T) \ + DECLARE_GPU_SPEC_FOR_OP(T, DenseUpdateType::ADD); \ + DECLARE_GPU_SPEC_FOR_OP(T, DenseUpdateType::SUB) +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +#undef DECLARE_GPU_SPEC +#undef DECLARE_GPU_SPEC_FOR_OP +} // namespace functor + +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint("T"), \ + DenseUpdateOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignSub").Device(DEVICE_GPU).TypeConstraint("T"), \ + DenseUpdateOp); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS +#endif // end GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dense_update_ops.h b/tensorflow/core/kernels/dense_update_ops.h new file mode 100644 index 0000000000..d32c9a4af2 --- /dev/null +++ b/tensorflow/core/kernels/dense_update_ops.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_ +#define TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +enum DenseUpdateType { ADD, SUB, ASSIGN }; + +namespace functor { + +template +struct DenseUpdate; + +template +struct DenseUpdate { + void operator()(const Device& d, typename TTypes::Flat params, + typename TTypes::ConstFlat update) { + params.device(d) += update; + } +}; + +template +struct DenseUpdate { + void operator()(const Device& d, typename TTypes::Flat params, + typename TTypes::ConstFlat update) { + params.device(d) -= update; + } +}; + +template +struct DenseUpdate { + void operator()(const Device& d, typename TTypes::Flat params, + typename TTypes::ConstFlat update) { + params.device(d) = update; + } +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_ diff --git a/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc b/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc new file mode 100644 index 0000000000..8e80901c71 --- /dev/null +++ b/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc @@ -0,0 +1,22 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/dense_update_ops.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::DenseUpdate; \ + template struct 
functor::DenseUpdate; \ + template struct functor::DenseUpdate; +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); +#undef DEFINE_GPU_KERNELS + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/determinant_op.cc b/tensorflow/core/kernels/determinant_op.cc new file mode 100644 index 0000000000..d34aab7a44 --- /dev/null +++ b/tensorflow/core/kernels/determinant_op.cc @@ -0,0 +1,66 @@ +// See docs in ../ops/linalg_ops.cc. +#include + +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/Eigen/LU" + +namespace tensorflow { + +template +class DeterminantOp : public LinearAlgebraOp { + public: + explicit DeterminantOp(OpKernelConstruction* context) + : LinearAlgebraOp(context) {} + ~DeterminantOp() override {} + + TensorShape GetOutputMatrixShape( + const TensorShape& input_matrix_shape) override { + return TensorShape({}); + } + + int64 GetCostPerUnit(const TensorShape& input_matrix_shape) override { + const int64 rows = input_matrix_shape.dim_size(0); + if (rows > (1LL << 20)) { + // A big number to cap the cost in case overflow. + return kint32max; + } else { + return rows * rows * rows; + } + } + + using typename LinearAlgebraOp::MatrixMap; + using + typename LinearAlgebraOp::ConstMatrixMap; + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input, + MatrixMap* output) override { + OP_REQUIRES(context, input.rows() == input.cols(), + errors::InvalidArgument("Input matrix must be square.")); + Scalar determinant; + if (input.rows() == 0) { + // An empty matrix' determinant is defined to be 1. See + // wikipedia. 
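// (The convention follows from the Leibniz formula: a 0 x 0 matrix admits only
// the empty permutation, and the empty product is 1.)
//
// (Aside: a standalone sketch of the non-empty path taken just below, which
// defers to Eigen's determinant(); illustrative only.)
#include <iostream>
#include "third_party/eigen3/Eigen/LU"

int main() {
  Eigen::Matrix2d m;
  m << 1.0, 2.0,
       3.0, 4.0;
  std::cout << m.determinant() << std::endl;  // prints -2 (= 1*4 - 2*3)
  return 0;
}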
+ determinant = 1; + } else { + determinant = input.determinant(); + } + OP_REQUIRES(context, std::isfinite(determinant), + errors::Internal("The determinant is not finite.")); + (*output)(0, 0) = determinant; + } +}; + +REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp), float); +REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp), double); +REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp), + float); +REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp), + double); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/diag_op.cc b/tensorflow/core/kernels/diag_op.cc new file mode 100644 index 0000000000..83e39d33a9 --- /dev/null +++ b/tensorflow/core/kernels/diag_op.cc @@ -0,0 +1,93 @@ +// See docs in ../ops/array_ops.cc +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace { +template +class DiagonalGenerator { + public: + explicit DiagonalGenerator(const Tensor& diagonal) : diagonal_(diagonal) { + static_assert(DoubleNumDims == 2 * NumDims, + "The second size must be the double of the first size."); + CHECK_EQ(diagonal.dims(), NumDims); + } + T operator()( + const Eigen::array& coordinates) const { + Eigen::array index; + for (int i = 0; i < NumDims; ++i) { + if (coordinates[i] != coordinates[NumDims + i]) { + return T(0); + } + index[i] = coordinates[i]; + } + return diagonal_.tensor()(index); + } + + private: + Tensor diagonal_; +}; +} // namespace + +// Generate the diagonal tensor with the diagonal set to the input tensor. +// It only allows up to rank 3 input tensor, so the output tensor is up to +// rank 6. 
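// (Aside: a standalone sketch, rank-1 case and illustrative names only, of the
// mapping the DiagonalGenerator above encodes: output(i, j) equals diagonal(i)
// when i == j and 0 otherwise.)
#include <cstddef>
#include <vector>

std::vector<std::vector<float>> DiagSketch(const std::vector<float>& diagonal) {
  const std::size_t n = diagonal.size();
  std::vector<std::vector<float>> out(n, std::vector<float>(n, 0.0f));
  for (std::size_t i = 0; i < n; ++i) out[i][i] = diagonal[i];
  return out;
}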
+template +class DiagOp : public OpKernel { + public: + explicit DiagOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& diagonal = context->input(0); + const int num_dims = diagonal.dims(); + OP_REQUIRES(context, 1 <= num_dims, + errors::InvalidArgument( + "The rank of the diagonal should be between 1 and 3.")); + OP_REQUIRES(context, 3 >= num_dims, + errors::InvalidArgument( + "The rank of the diagonal should be between 1 and 3.")); + TensorShape out_shape; + for (int i = 0; i < num_dims; ++i) { + out_shape.AddDim(diagonal.dim_size(i)); + } + for (int i = 0; i < num_dims; ++i) { + out_shape.AddDim(diagonal.dim_size(i)); + } + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, out_shape, &output_tensor)); + switch (num_dims) { + case 1: + output_tensor->tensor() = output_tensor->tensor().generate( + DiagonalGenerator(diagonal)); + break; + case 2: + output_tensor->tensor() = output_tensor->tensor().generate( + DiagonalGenerator(diagonal)); + break; + case 3: + output_tensor->tensor() = output_tensor->tensor().generate( + DiagonalGenerator(diagonal)); + break; + default: + context->SetStatus(errors::Unimplemented( + "Diagonal of rank ", num_dims, " tensor is not supported yet.")); + return; + } + } +}; + +#define REGISTER_DIAGOP(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Diag").Device(DEVICE_CPU).TypeConstraint("T"), DiagOp) + +REGISTER_DIAGOP(double); +REGISTER_DIAGOP(float); +REGISTER_DIAGOP(int32); +REGISTER_DIAGOP(int64); + +#undef REGISTER_DIAGOP +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dynamic_partition_op.cc b/tensorflow/core/kernels/dynamic_partition_op.cc new file mode 100644 index 0000000000..f1b44861b5 --- /dev/null +++ b/tensorflow/core/kernels/dynamic_partition_op.cc @@ -0,0 +1,154 @@ +// See docs in ../ops/data_flow_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +// Shared code that is not dependent on the type of T. We do this to reduce +// code size by not duplicating all this for all T (float, double, int32, etc.) +class DynamicPartitionOp_Shared : public OpKernel { + public: + explicit DynamicPartitionOp_Shared(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_)); + // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, etc. + // to input[1]. Should we have the framework do some sort of + // integer promotion automatically, or should that be something + // that users have to do explicitly with a conversion operator + // in the graph? 
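// (Aside: a standalone sketch, rank-1 case and illustrative names only, of the
// contract that ValidateAndAllocateOutputs and Compute below implement:
// element i of data is routed to output number partitions(i).)
#include <cstddef>
#include <vector>

std::vector<std::vector<float>> DynamicPartitionSketch(
    const std::vector<float>& data, const std::vector<int>& partitions,
    int num_partitions) {
  std::vector<std::vector<float>> outputs(num_partitions);
  for (std::size_t i = 0; i < data.size(); ++i) {
    // The kernel rejects out-of-range ids; assume 0 <= partitions[i] < num_partitions.
    outputs[partitions[i]].push_back(data[i]);
  }
  return outputs;
}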
+ } + + void ValidateAndAllocateOutputs(OpKernelContext* c, const Tensor** data, + const Tensor** partitions, + OpOutputList* Tout) { + OP_REQUIRES_OK(c, c->input("data", data)); + OP_REQUIRES_OK(c, c->input("partitions", partitions)); + OP_REQUIRES(c, TensorShapeUtils::StartsWith((*data)->shape(), + (*partitions)->shape()), + errors::InvalidArgument( + "data.shape must start with partitions.shape, ", + "got data.shape = ", (*data)->shape().ShortDebugString(), + ", partitions.shape = ", + (*partitions)->shape().ShortDebugString())); + + // Count how many occurrences of each partition id we have in partitions + gtl::InlinedVector partition_count(num_partitions_); + auto e_partitions = (*partitions)->flat(); + const int64 N = e_partitions.dimension(0); + for (int64 i = 0; i < N; i++) { + const int32 p = e_partitions(i); + OP_REQUIRES(c, p >= 0 && p < num_partitions_, + errors::InvalidArgument( + "partitions", SliceString((*partitions)->shape(), i), + " = ", p, " is not in [0, ", num_partitions_, ")")); + partition_count[p]++; + } + + // Allocate output tensors of the right size + OP_REQUIRES_OK(c, c->output_list("outputs", Tout)); + for (int p = 0; p < num_partitions_; p++) { + TensorShape shape; + shape.AddDim(partition_count[p]); + for (int i = (*partitions)->dims(); i < (*data)->dims(); i++) { + shape.AddDim((*data)->dim_size(i)); + } + Tensor* out; + OP_REQUIRES_OK(c, Tout->allocate(p, shape, &out)); + } + } + + protected: + int num_partitions_; + + static string SliceString(const TensorShape& shape, const int64 flat) { + // Special case rank 0 and 1 + const int dims = shape.dims(); + if (dims == 0) return ""; + if (dims == 1) return strings::StrCat("[", flat, "]"); + + // Compute strides + gtl::InlinedVector strides(dims); + strides.back() = 1; + for (int i = dims - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * shape.dim_size(i + 1); + } + + // Unflatten index + int64 left = flat; + string result; + for (int i = 0; i < dims; i++) { + strings::StrAppend(&result, i ? 
"," : "[", left / strides[i]); + left %= strides[i]; + } + strings::StrAppend(&result, "]"); + return result; + } +}; + +template +class DynamicPartitionOp : public DynamicPartitionOp_Shared { + public: + explicit DynamicPartitionOp(OpKernelConstruction* c) + : DynamicPartitionOp_Shared(c) {} + void Compute(OpKernelContext* c) override { + const Tensor* data; + const Tensor* partitions; + OpOutputList outputs; + ValidateAndAllocateOutputs(c, &data, &partitions, &outputs); + if (!c->status().ok()) return; + if (num_partitions_ == 0 || data->NumElements() == 0) return; + + auto e_partitions = partitions->flat(); + const int64 N = e_partitions.dimension(0); + gtl::InlinedVector output_index(num_partitions_); + + if (partitions->dims() == data->dims()) { + // Walk through data and copy the data to the appropriate output tensor + const auto data_flat = data->flat(); + std::vector, + Eigen::Aligned> > out_vec; + for (int p = 0; p < num_partitions_; p++) { + out_vec.push_back(outputs[p]->vec()); + } + for (int64 i = 0; i < N; i++) { + const int32 p = e_partitions(i); + out_vec[p](output_index[p]) = data_flat(i); + output_index[p]++; + } + } else { + // If data has extra dimensions, use Eigen slices + std::vector, + Eigen::Aligned> > out_flat; + for (int p = 0; p < num_partitions_; p++) { + out_flat.push_back(outputs[p]->flat_outer_dims()); + } + + // Walk through data and copy the data to the appropriate output tensor + const int64 slice_size = data->NumElements() / N; + const auto data_flat = data->shaped({N, slice_size}); + Eigen::DSizes sizes(1, slice_size); + for (int64 i = 0; i < N; i++) { + const int32 p = e_partitions(i); + // outputs[p][output_index[p]++] = data[i] + Eigen::DSizes out_indices(output_index[p], 0); + Eigen::DSizes data_indices(i, 0); + out_flat[p].slice(out_indices, sizes) = + data_flat.slice(data_indices, sizes); + output_index[p]++; + } + } + } +}; + +#define REGISTER_DYNAMIC_PARTITION(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("DynamicPartition").Device(DEVICE_CPU).TypeConstraint("T"), \ + DynamicPartitionOp) + +TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION); +#undef REGISTER_DYNAMIC_PARTITION + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dynamic_partition_op_test.cc b/tensorflow/core/kernels/dynamic_partition_op_test.cc new file mode 100644 index 0000000000..b0e5e7deb0 --- /dev/null +++ b/tensorflow/core/kernels/dynamic_partition_op_test.cc @@ -0,0 +1,145 @@ +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +class DynamicPartitionOpTest : public OpsTestBase { + protected: + void MakeOp() { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "DynamicPartition") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Attr("num_partitions", 4) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(DynamicPartitionOpTest, Simple_OneD) { + MakeOp(); + + // Similar to how we would use this to split embedding ids to be looked up + + // Feed and run + 
AddInputFromArray(TensorShape({6}), {0, 13, 2, 39, 4, 17}); + AddInputFromArray(TensorShape({6}), {0, 0, 2, 3, 2, 1}); + ASSERT_OK(RunOpKernel()); + + // Check the output sizes + { // Output 0 + Tensor expected(allocator(), DT_FLOAT, TensorShape({2})); + test::FillValues(&expected, {0, 13}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + { // Output 1 + Tensor expected(allocator(), DT_FLOAT, TensorShape({1})); + test::FillValues(&expected, {17}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + { // Output 2 + Tensor expected(allocator(), DT_FLOAT, TensorShape({2})); + test::FillValues(&expected, {2, 4}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } + { // Output 3 + Tensor expected(allocator(), DT_FLOAT, TensorShape({1})); + test::FillValues(&expected, {39}); + test::ExpectTensorEqual(expected, *GetOutput(3)); + } +} + +TEST_F(DynamicPartitionOpTest, Simple_TwoD) { + MakeOp(); + + // Feed and run + AddInputFromArray( + TensorShape({6, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); + AddInputFromArray(TensorShape({6}), {0, 0, 2, 3, 2, 1}); + ASSERT_OK(RunOpKernel()); + + // Check the output sizes + { // Output 0 + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + test::FillValues(&expected, {0, 1, 2, 3, 4, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + { // Output 1 + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3})); + test::FillValues(&expected, {15, 16, 17}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + { // Output 2 + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + test::FillValues(&expected, {6, 7, 8, 12, 13, 14}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } + { // Output 3 + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3})); + test::FillValues(&expected, {9, 10, 11}); + test::ExpectTensorEqual(expected, *GetOutput(3)); + } +} + +TEST_F(DynamicPartitionOpTest, SomeOutputsEmpty) { + MakeOp(); + + // Feed and run + AddInputFromArray(TensorShape({6}), {0, 13, 2, 39, 4, 17}); + AddInputFromArray(TensorShape({6}), {0, 0, 2, 2, 0, 2}); + ASSERT_OK(RunOpKernel()); + + TensorShape empty_one_dim; + empty_one_dim.AddDim(0); + Tensor expected_empty(allocator(), DT_FLOAT, empty_one_dim); + + // Check the output sizes + { // Output 0 + Tensor expected(allocator(), DT_FLOAT, TensorShape({3})); + test::FillValues(&expected, {0, 13, 4}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + { // Output 1 + test::ExpectTensorEqual(expected_empty, *GetOutput(1)); + } + { // Output 2 + Tensor expected(allocator(), DT_FLOAT, TensorShape({3})); + test::FillValues(&expected, {2, 39, 17}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } + { // Output 3 + test::ExpectTensorEqual(expected_empty, *GetOutput(3)); + } +} + +TEST_F(DynamicPartitionOpTest, Error_IndexOutOfRange) { + MakeOp(); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray(TensorShape({5}), {0, 2, 99, 2, 2}); + Status s = RunOpKernel(); + EXPECT_TRUE( + StringPiece(s.ToString()).contains("partitions[2] = 99 is not in [0, 4)")) + << s; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc new file mode 100644 index 0000000000..a5623685fb --- /dev/null +++ b/tensorflow/core/kernels/dynamic_stitch_op.cc @@ -0,0 +1,158 @@ +// See docs in ../ops/data_flow_ops.cc. 
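// (Aside: a standalone sketch, rank-1 case and illustrative names only, of the
// merge rule the kernel below implements: merged[indices[m][i]] = data[m][i].
// Duplicate indices keep whichever value is copied last; rows never indexed are
// zero-filled here, whereas the kernel currently leaves them uninitialized.)
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<float> DynamicStitchSketch(
    const std::vector<std::vector<int>>& indices,
    const std::vector<std::vector<float>>& data) {
  int max_index = -1;
  for (const auto& idx : indices)
    for (int i : idx) max_index = std::max(max_index, i);
  std::vector<float> merged(max_index + 1, 0.0f);
  for (std::size_t m = 0; m < indices.size(); ++m)
    for (std::size_t i = 0; i < indices[m].size(); ++i)
      merged[indices[m][i]] = data[m][i];
  return merged;
}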
+ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +template +class DynamicStitchOp : public OpKernel { + public: + explicit DynamicStitchOp(OpKernelConstruction* c) : OpKernel(c) { + // Compute expected input signature + const DataType dt = DataTypeToEnum::v(); + const int n = c->num_inputs() / 2; + DataTypeVector expected; + for (int i = 0; i < n; i++) { + expected.push_back(DT_INT32); + } + for (int i = 0; i < n; i++) { + expected.push_back(dt); + } + OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt})); + OP_REQUIRES( + c, c->num_inputs() > 0, + errors::InvalidArgument("DynamicStitchOp: Must have some inputs")); + OP_REQUIRES(c, c->num_inputs() % 2 == 0, + errors::InvalidArgument( + "DynamicStitchOp: Must have even number of arguments")); + } + + void Compute(OpKernelContext* c) override { + // Find maximum index in the indices vectors + OpInputList indices_inputs; + OP_REQUIRES_OK(c, c->input_list("indices", &indices_inputs)); + + int32 max_index = -1; + for (const Tensor& indices : indices_inputs) { + Eigen::Tensor m = + indices.flat().maximum(); + max_index = std::max(m(), max_index); + } + const int first_dim_size = max_index + 1; + + // Validate that data[i].shape = indices[i].shape + constant + OpInputList data_inputs; + OP_REQUIRES_OK(c, c->input_list("data", &data_inputs)); + const Tensor& data0 = data_inputs[0]; + const Tensor& indices0 = indices_inputs[0]; + for (int input_num = 0; input_num < indices_inputs.size(); input_num++) { + const Tensor& indices = indices_inputs[input_num]; + const Tensor& data = data_inputs[input_num]; + OP_REQUIRES( + c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()), + errors::InvalidArgument( + "data[", input_num, "].shape = ", data.shape().ShortDebugString(), + " does not start with indices[", input_num, "].shape = ", + indices.shape().ShortDebugString())); + OP_REQUIRES( + c, input_num == 0 || SameExtraShape(data0, indices0, data, indices), + errors::InvalidArgument( + "Need data[0].shape[", indices0.dims(), ":] = data[", input_num, + "].shape[", indices.dims(), ":], got data[0].shape = ", + data0.shape().ShortDebugString(), ", data[", input_num, + "].shape = ", data.shape().ShortDebugString(), + ", indices[0].shape = ", indices0.shape().ShortDebugString(), + ", indices[", input_num, "].shape = ", + indices.shape().ShortDebugString())); + } + + // Allocate result tensor of shape + // [first_dim_size] + data.shape[indices.dims:] + TensorShape result_shape; + result_shape.AddDim(first_dim_size); + for (int d = indices0.dims(); d < data0.dims(); d++) { + result_shape.AddDim(data0.dim_size(d)); + } + Tensor* merged = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &merged)); + + // TODO(jeff): Currently we leave uninitialized any portions of + // merged that aren't covered by an index in indices. What should we do? 
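// (Aside: a standalone illustration, with an illustrative name, of the two
// row-copy strategies the loop below chooses between -- a raw memcpy when the
// element type can be copied bytewise, element-wise assignment otherwise,
// e.g. for strings.)
#include <cstddef>
#include <cstring>
#include <string>
#include <type_traits>

template <typename T>
void CopyRowSketch(T* dst, const T* src, std::size_t n) {
  if (std::is_trivially_copyable<T>::value) {
    std::memcpy(dst, src, n * sizeof(T));                 // float, int32, ...
  } else {
    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];  // std::string, ...
  }
}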
+ if (first_dim_size > 0) { + auto merged_flat = merged->flat_outer_dims(); + const int slice_size = merged_flat.dimension(1); + for (int input_num = 0; input_num < indices_inputs.size(); input_num++) { + const Tensor& indices = indices_inputs[input_num]; + auto indices_vec = indices.flat(); + const Tensor& data = data_inputs[input_num]; + auto data_flat = + data.shaped({indices_vec.dimension(0), slice_size}); + + if (DataTypeCanUseMemcpy(DataTypeToEnum::v())) { + T* merged_base = &merged_flat(0, 0); + const T* data_base = &data_flat(0, 0); + const size_t slice_bytes = slice_size * sizeof(T); + for (int i = 0; i < indices_vec.size(); i++) { + memcpy(merged_base + indices_vec(i) * slice_size, + data_base + i * slice_size, slice_bytes); + } + } else { + Eigen::DSizes sizes(1, slice_size); + for (int i = 0; i < indices_vec.size(); i++) { + // Copy slice data[i] to merged[indices[i]] + Eigen::DSizes data_indices(i, 0); + Eigen::DSizes merged_indices(indices_vec(i), + 0); + merged_flat.slice(merged_indices, sizes) = + data_flat.slice(data_indices, sizes); + } + } + } + } + } + + private: + // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():] + static bool SameExtraShape(const Tensor& data0, const Tensor& indices0, + const Tensor& data1, const Tensor& indices1) { + const int extra0 = data0.dims() - indices0.dims(); + const int extra1 = data1.dims() - indices1.dims(); + if (extra0 != extra1) return false; + for (int i = 0; i < extra0; i++) { + if (data0.dim_size(indices0.dims() + i) != + data1.dim_size(indices1.dims() + i)) { + return false; + } + } + return true; + } +}; + +#define REGISTER_DYNAMIC_STITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("indices"), \ + DynamicStitchOp) + +TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH); +#undef REGISTER_DYNAMIC_STITCH + +#if GOOGLE_CUDA +#define REGISTER_DYNAMIC_STITCH_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("indices") \ + .HostMemory("data") \ + .HostMemory("merged"), \ + DynamicStitchOp) + +TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH_GPU); +#undef REGISTER_DYNAMIC_STITCH_GPU + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dynamic_stitch_op_test.cc b/tensorflow/core/kernels/dynamic_stitch_op_test.cc new file mode 100644 index 0000000000..8c71f0fd0f --- /dev/null +++ b/tensorflow/core/kernels/dynamic_stitch_op_test.cc @@ -0,0 +1,133 @@ +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +class DynamicStitchOpTest : public OpsTestBase { + protected: + void MakeOp(int n, DataType dt) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "DynamicStitch") + .Input(FakeInput(n, DT_INT32)) + .Input(FakeInput(n, dt)) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } 
+}; + +TEST_F(DynamicStitchOpTest, Simple_OneD) { + MakeOp(2, DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({3}), {0, 4, 7}); + AddInputFromArray(TensorShape({5}), {1, 6, 2, 3, 5}); + AddInputFromArray(TensorShape({3}), {0, 40, 70}); + AddInputFromArray(TensorShape({5}), {10, 60, 20, 30, 50}); + ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({8})); + test::FillValues(&expected, {0, 10, 20, 30, 40, 50, 60, 70}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(DynamicStitchOpTest, Simple_TwoD) { + MakeOp(3, DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({3}), {0, 4, 7}); + AddInputFromArray(TensorShape({2}), {1, 6}); + AddInputFromArray(TensorShape({3}), {2, 3, 5}); + AddInputFromArray(TensorShape({3, 2}), {0, 1, 40, 41, 70, 71}); + AddInputFromArray(TensorShape({2, 2}), {10, 11, 60, 61}); + AddInputFromArray(TensorShape({3, 2}), {20, 21, 30, 31, 50, 51}); + ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({8, 2})); + test::FillValues(&expected, {0, 1, 10, 11, 20, 21, 30, 31, 40, 41, 50, + 51, 60, 61, 70, 71}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(DynamicStitchOpTest, Error_IndicesMultiDimensional) { + MakeOp(2, DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({3}), {0, 4, 7}); + AddInputFromArray(TensorShape({1, 5}), {1, 6, 2, 3, 5}); + AddInputFromArray(TensorShape({3}), {0, 40, 70}); + AddInputFromArray(TensorShape({5}), {10, 60, 20, 30, 50}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("data[1].shape = [5] does not start with " + "indices[1].shape = [1,5]")) + << s; +} + +TEST_F(DynamicStitchOpTest, Error_DataNumDimsMismatch) { + MakeOp(2, DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({3}), {0, 4, 7}); + AddInputFromArray(TensorShape({5}), {1, 6, 2, 3, 5}); + AddInputFromArray(TensorShape({3}), {0, 40, 70}); + AddInputFromArray(TensorShape({1, 5}), {10, 60, 20, 30, 50}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("data[1].shape = [1,5] does not start with " + "indices[1].shape = [5]")) + << s; +} + +TEST_F(DynamicStitchOpTest, Error_DataDimSizeMismatch) { + MakeOp(2, DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({3}), {0, 4, 5}); + AddInputFromArray(TensorShape({4}), {1, 6, 2, 3}); + AddInputFromArray(TensorShape({3, 1}), {0, 40, 70}); + AddInputFromArray(TensorShape({4, 2}), + {10, 11, 60, 61, 20, 21, 30, 31}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Need data[0].shape[1:] = data[1].shape[1:], " + "got data[0].shape = [3,1], data[1].shape = [4,2]")) + << s; +} + +TEST_F(DynamicStitchOpTest, Error_DataAndIndicesSizeMismatch) { + MakeOp(2, DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({3}), {0, 4, 7}); + AddInputFromArray(TensorShape({5}), {1, 6, 2, 3, 5}); + AddInputFromArray(TensorShape({3}), {0, 40, 70}); + AddInputFromArray(TensorShape({4}), {10, 60, 20, 30}); + Status s = RunOpKernel(); + EXPECT_TRUE( + StringPiece(s.ToString()) + .contains( + "data[1].shape = [4] does not start with indices[1].shape = [5]")) + << s; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc new file mode 100644 index 0000000000..938d7f056b --- /dev/null +++ b/tensorflow/core/kernels/edit_distance_op.cc @@ -0,0 +1,217 @@ +// See docs in 
../ops/array_ops.cc. + +#define EIGEN_USE_THREADS + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/gtl/edit_distance.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +Status ValidateShapes(OpKernelContext* ctx, const Tensor& hypothesis_indices, + const Tensor& hypothesis_values, + const Tensor& hypothesis_shape, + const Tensor& truth_indices, const Tensor& truth_values, + const Tensor& truth_shape) { + if (!TensorShapeUtils::IsMatrix(hypothesis_indices.shape())) + return errors::InvalidArgument( + "hypothesis_indices should be a matrix, but got shape: ", + hypothesis_indices.shape().DebugString()); + if (!TensorShapeUtils::IsMatrix(truth_indices.shape())) + return errors::InvalidArgument( + "truth_indices should be a matrix, but got shape: ", + truth_indices.shape().DebugString()); + if (!TensorShapeUtils::IsVector(hypothesis_values.shape())) + return errors::InvalidArgument( + "hypothesis_values should be a vector, but got shape: ", + hypothesis_values.shape().DebugString()); + if (!TensorShapeUtils::IsVector(truth_values.shape())) + return errors::InvalidArgument( + "truth_values should be a vector, but got shape: ", + truth_values.shape().DebugString()); + if (!TensorShapeUtils::IsVector(hypothesis_shape.shape())) + return errors::InvalidArgument( + "hypothesis_shape should be a vector, but got shape: ", + hypothesis_shape.shape().DebugString()); + if (!TensorShapeUtils::IsVector(truth_shape.shape())) + return errors::InvalidArgument( + "truth_shape should be a vector, but got shape: ", + truth_shape.shape().DebugString()); + if (hypothesis_shape.NumElements() != hypothesis_indices.dim_size(1)) + return errors::InvalidArgument( + "Expected hypothesis_shape.NumElements == " + "#cols(hypothesis_indices), their shapes are: ", + hypothesis_shape.shape().DebugString(), " and ", + hypothesis_indices.shape().DebugString()); + if (truth_shape.NumElements() < 2) + return errors::InvalidArgument( + "Input SparseTensors must have rank at least 2, but truth_shape " + "rank is: ", + truth_shape.NumElements()); + if (truth_shape.NumElements() != truth_indices.dim_size(1)) + return errors::InvalidArgument( + "Expected truth_shape.NumElements == " + "#cols(truth_indices), their shapes are: ", + truth_shape.shape().DebugString(), " and ", + truth_indices.shape().DebugString()); + if (truth_shape.NumElements() != hypothesis_shape.NumElements()) + return errors::InvalidArgument( + "Expected truth and hypothesis to have matching ranks, but " + "their shapes are: ", + truth_shape.shape().DebugString(), " and ", + hypothesis_shape.shape().DebugString()); + + return Status::OK(); +} + +} // namespace + +template +class EditDistanceOp : public OpKernel { + public: + explicit EditDistanceOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("normalize", &normalize_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* hypothesis_indices; + const Tensor* hypothesis_values; + const Tensor* hypothesis_shape; + const Tensor* truth_indices; + const Tensor* truth_values; + const Tensor* truth_shape; + OP_REQUIRES_OK(ctx, ctx->input("hypothesis_indices", &hypothesis_indices)); + 
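// (Aside: the gtl::LevenshteinDistance helper invoked later in this method
// computes a standard edit distance. A minimal standalone dynamic-programming
// sketch over std::string, for reference only -- not the library
// implementation.)
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

std::size_t LevenshteinSketch(const std::string& a, const std::string& b) {
  std::vector<std::size_t> prev(b.size() + 1), cur(b.size() + 1);
  for (std::size_t j = 0; j <= b.size(); ++j) prev[j] = j;
  for (std::size_t i = 1; i <= a.size(); ++i) {
    cur[0] = i;
    for (std::size_t j = 1; j <= b.size(); ++j) {
      const std::size_t substitute = prev[j - 1] + (a[i - 1] == b[j - 1] ? 0 : 1);
      cur[j] = std::min({prev[j] + 1, cur[j - 1] + 1, substitute});
    }
    std::swap(prev, cur);
  }
  return prev[b.size()];
}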
OP_REQUIRES_OK(ctx, ctx->input("hypothesis_values", &hypothesis_values)); + OP_REQUIRES_OK(ctx, ctx->input("hypothesis_shape", &hypothesis_shape)); + OP_REQUIRES_OK(ctx, ctx->input("truth_indices", &truth_indices)); + OP_REQUIRES_OK(ctx, ctx->input("truth_values", &truth_values)); + OP_REQUIRES_OK(ctx, ctx->input("truth_shape", &truth_shape)); + + OP_REQUIRES_OK( + ctx, ValidateShapes(ctx, *hypothesis_indices, *hypothesis_values, + *hypothesis_shape, *truth_indices, *truth_values, + *truth_shape)); + + TensorShape hypothesis_st_shape = TensorShapeUtils::MakeShape( + hypothesis_shape->vec().data(), hypothesis_shape->NumElements()); + TensorShape truth_st_shape = TensorShapeUtils::MakeShape( + truth_shape->vec().data(), truth_shape->NumElements()); + + // Assume indices are sorted in row-major order. + std::vector sorted_order(truth_st_shape.dims()); + std::iota(sorted_order.begin(), sorted_order.end(), 0); + + sparse::SparseTensor hypothesis(*hypothesis_indices, *hypothesis_values, + hypothesis_st_shape, sorted_order); + sparse::SparseTensor truth(*truth_indices, *truth_values, truth_st_shape, + sorted_order); + + // Group dims 0, 1, ..., RANK - 1. The very last dim is assumed + // to store the variable length sequences. + std::vector group_dims(truth_st_shape.dims() - 1); + std::iota(group_dims.begin(), group_dims.end(), 0); + + TensorShape output_shape; + for (int d = 0; d < group_dims.size(); ++d) { + output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d), + truth_st_shape.dim_size(d))); + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("output", output_shape, &output)); + auto output_t = output->flat(); + output_t.setZero(); + + std::vector output_strides(output_shape.dims()); + output_strides[output_shape.dims() - 1] = 1; + for (int d = output_shape.dims() - 2; d >= 0; --d) { + output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1); + } + + auto hypothesis_grouper = hypothesis.group(group_dims); + auto truth_grouper = truth.group(group_dims); + + auto hypothesis_iter = hypothesis_grouper.begin(); + auto truth_iter = truth_grouper.begin(); + + auto cmp = std::equal_to(); + + while (hypothesis_iter != hypothesis_grouper.end() && + truth_iter != truth_grouper.end()) { + sparse::Group truth_i = *truth_iter; + sparse::Group hypothesis_j = *hypothesis_iter; + std::vector g_truth = truth_i.group(); + std::vector g_hypothesis = hypothesis_j.group(); + auto truth_seq = truth_i.values(); + auto hypothesis_seq = hypothesis_j.values(); + + if (g_truth == g_hypothesis) { + auto loc = std::inner_product(g_truth.begin(), g_truth.end(), + output_strides.begin(), 0); + output_t(loc) = + gtl::LevenshteinDistance(truth_seq, hypothesis_seq, cmp); + if (normalize_) output_t(loc) /= truth_seq.size(); + + ++hypothesis_iter; + ++truth_iter; + } else if (g_truth > g_hypothesis) { // missing truth @ this hypothesis + auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(), + output_strides.begin(), 0); + output_t(loc) = hypothesis_seq.size(); + if (normalize_) output_t(loc) /= 0.0; + ++hypothesis_iter; + } else { // missing hypothesis @ this truth + auto loc = std::inner_product(g_truth.begin(), g_truth.end(), + output_strides.begin(), 0); + output_t(loc) = (normalize_) ? 
1.0 : truth_seq.size(); + ++truth_iter; + } + } + while (hypothesis_iter != hypothesis_grouper.end()) { // missing truths + sparse::Group hypothesis_j = *hypothesis_iter; + std::vector g_hypothesis = hypothesis_j.group(); + auto hypothesis_seq = hypothesis_j.values(); + auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(), + output_strides.begin(), 0); + output_t(loc) = hypothesis_seq.size(); + if (normalize_) output_t(loc) /= 0.0; + ++hypothesis_iter; + } + while (truth_iter != truth_grouper.end()) { // missing hypotheses + sparse::Group truth_i = *truth_iter; + std::vector g_truth = truth_i.group(); + auto truth_seq = truth_i.values(); + auto loc = std::inner_product(g_truth.begin(), g_truth.end(), + output_strides.begin(), 0); + output_t(loc) = (normalize_) ? 1.0 : truth_seq.size(); + ++truth_iter; + } + } + + private: + bool normalize_; + + TF_DISALLOW_COPY_AND_ASSIGN(EditDistanceOp); +}; + +#define REGISTER_CPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("EditDistance").Device(DEVICE_CPU).TypeConstraint("T"), \ + EditDistanceOp); + +TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL); + +#undef REGISTER_CPU_KERNEL + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/encode_jpeg_op.cc b/tensorflow/core/kernels/encode_jpeg_op.cc new file mode 100644 index 0000000000..8f5fd2f8be --- /dev/null +++ b/tensorflow/core/kernels/encode_jpeg_op.cc @@ -0,0 +1,114 @@ +// See docs in ../ops/image_ops.cc + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/jpeg/jpeg_mem.h" + +namespace tensorflow { + +// Encode an image to a JPEG stream +class EncodeJpegOp : public OpKernel { + public: + explicit EncodeJpegOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("format", &format_)); + if (format_.empty()) { + flags_.format = static_cast(0); + } else if (format_ == "grayscale") { + flags_.format = jpeg::FORMAT_GRAYSCALE; + } else if (format_ == "rgb") { + flags_.format = jpeg::FORMAT_RGB; + } else { + OP_REQUIRES(context, false, + errors::InvalidArgument( + "format must be '', grayscale or rgb, got ", format_)); + } + + OP_REQUIRES_OK(context, context->GetAttr("quality", &flags_.quality)); + OP_REQUIRES(context, 0 <= flags_.quality && flags_.quality <= 100, + errors::InvalidArgument("quality must be in [0,100], got ", + flags_.quality)); + OP_REQUIRES_OK(context, + context->GetAttr("progressive", &flags_.progressive)); + OP_REQUIRES_OK( + context, context->GetAttr("optimize_size", &flags_.optimize_jpeg_size)); + OP_REQUIRES_OK(context, context->GetAttr("chroma_downsampling", + &flags_.chroma_downsampling)); + OP_REQUIRES_OK(context, context->GetAttr("chroma_downsampling", + &flags_.chroma_downsampling)); + + string density_unit; + OP_REQUIRES_OK(context, context->GetAttr("density_unit", &density_unit)); + if (density_unit == "in") { + flags_.density_unit = 1; + } else if (density_unit == "cm") { + flags_.density_unit = 2; + } else { + OP_REQUIRES(context, false, + errors::InvalidArgument("density_unit must be 'in' or 'cm'", + density_unit)); + } + + OP_REQUIRES_OK(context, context->GetAttr("x_density", &flags_.x_density)); + OP_REQUIRES_OK(context, context->GetAttr("y_density", &flags_.y_density)); + 
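// (Aside: the "in" -> 1 and "cm" -> 2 mapping above matches the JFIF APP0
// density-unit codes: 1 = pixels per inch, 2 = pixels per centimetre (0 would
// mean "no units, aspect ratio only"). A standalone sketch of the same mapping,
// illustrative name only.)
#include <stdexcept>
#include <string>

int DensityUnitCode(const std::string& unit) {
  if (unit == "in") return 1;
  if (unit == "cm") return 2;
  throw std::invalid_argument("density_unit must be 'in' or 'cm'");
}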
OP_REQUIRES_OK(context, context->GetAttr("xmp_metadata", &xmp_metadata_)); + flags_.xmp_metadata = xmp_metadata_; // StringPiece doesn't own data + } + + void Compute(OpKernelContext* context) override { + const Tensor& image = context->input(0); + OP_REQUIRES(context, image.dims() == 3, + errors::InvalidArgument("image must be 3-dimensional", + image.shape().ShortDebugString())); + + // Autodetect format if desired, otherwise make sure format and + // image channels are consistent. + int channels; + jpeg::CompressFlags adjusted_flags = flags_; + if (flags_.format == 0) { + channels = image.dim_size(2); + if (channels == 1) { + adjusted_flags.format = jpeg::FORMAT_GRAYSCALE; + } else if (channels == 3) { + adjusted_flags.format = jpeg::FORMAT_RGB; + } else { + OP_REQUIRES(context, false, errors::InvalidArgument( + "image must have 1 or 3 channels, got ", + image.shape().ShortDebugString())); + } + } else { + if (flags_.format == jpeg::FORMAT_GRAYSCALE) { + channels = 1; + } else { // RGB + channels = 3; + } + OP_REQUIRES(context, channels == image.dim_size(2), + errors::InvalidArgument("format ", format_, " expects ", + channels, " channels, got ", + image.shape().ShortDebugString())); + } + + // Encode image to jpeg string + Tensor* output = NULL; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + OP_REQUIRES(context, + jpeg::Compress(image.flat().data(), image.dim_size(1), + image.dim_size(0), adjusted_flags, + &output->scalar()()), + errors::Internal("JPEG encoding failed")); + } + + private: + string format_; + string xmp_metadata_; // Owns data referenced by flags_ + jpeg::CompressFlags flags_; +}; +REGISTER_KERNEL_BUILDER(Name("EncodeJpeg").Device(DEVICE_CPU), EncodeJpegOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/encode_png_op.cc b/tensorflow/core/kernels/encode_png_op.cc new file mode 100644 index 0000000000..5249074377 --- /dev/null +++ b/tensorflow/core/kernels/encode_png_op.cc @@ -0,0 +1,52 @@ +// See docs in ../ops/image_ops.cc + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/png/png_io.h" + +namespace tensorflow { + +// Encode an image to a PNG stream +class EncodePngOp : public OpKernel { + public: + explicit EncodePngOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("compression", &compression_)); + OP_REQUIRES(context, -1 <= compression_ && compression_ <= 9, + errors::InvalidArgument("compression should be in [-1,9], got ", + compression_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& image = context->input(0); + OP_REQUIRES(context, image.dims() == 3, + errors::InvalidArgument("image must be 3-dimensional", + image.shape().ShortDebugString())); + const int64 channels = image.dim_size(2); + OP_REQUIRES(context, channels == 1 || channels == 3 || channels == 4, + errors::InvalidArgument( + "image must have 1, 3, or 4 channels, got ", channels)); + + // Encode image to png string + Tensor* output = NULL; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + OP_REQUIRES(context, + png::WriteImageToBuffer( + image.flat().data(), image.dim_size(1), + image.dim_size(0), image.dim_size(1) 
* channels, channels, + 8, compression_, &output->scalar()(), nullptr), + errors::Internal("PNG encoding failed")); + } + + private: + int compression_; +}; +REGISTER_KERNEL_BUILDER(Name("EncodePng").Device(DEVICE_CPU), EncodePngOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc new file mode 100644 index 0000000000..c217c18207 --- /dev/null +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -0,0 +1,444 @@ +// See docs in ../ops/parsing_ops.cc. + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +Status CheckValidType(const DataType& dtype) { + switch (dtype) { + case DT_INT64: + case DT_FLOAT: + case DT_STRING: + return Status::OK(); + default: + return errors::InvalidArgument("Received input dtype: ", + DataTypeString(dtype)); + } +} + +Status CheckTypesMatch(const Feature& feature, const DataType& dtype, + bool* match) { + switch (dtype) { + case DT_INT64: + *match = (feature.kind_case() == Feature::kInt64List); + break; + case DT_FLOAT: + *match = (feature.kind_case() == Feature::kFloatList); + break; + case DT_STRING: + *match = (feature.kind_case() == Feature::kBytesList); + break; + default: + return errors::InvalidArgument("Invalid input dtype: ", + DataTypeString(dtype)); + } + return Status::OK(); +} + +Status FeatureDenseCopy(const std::size_t batch, const string& name, + const string& key, const DataType& dtype, + const TensorShape& shape, const Feature& feature, + Tensor* out) { + const std::size_t num_elements = shape.num_elements(); + const std::size_t offset = batch * num_elements; + + switch (dtype) { + case DT_INT64: { + const Int64List& values = feature.int64_list(); + if (static_cast(values.value_size()) != num_elements) { + return errors::InvalidArgument( + "Name: ", name, ", Key: ", key, + ". Number of int64 values != expected. " + "values size: ", + values.value_size(), " but output shape: ", + shape.ShortDebugString()); + } + auto out_p = out->flat().data() + offset; + std::copy_n(values.value().data(), num_elements, out_p); + return Status::OK(); + } + case DT_FLOAT: { + const FloatList& values = feature.float_list(); + if (static_cast(values.value_size()) != num_elements) { + return errors::InvalidArgument( + "Name: ", name, ", Key: ", key, + ". Number of float values != expected. " + "values size: ", + values.value_size(), " but output shape: ", + shape.ShortDebugString()); + } + auto out_p = out->flat().data() + offset; + std::copy_n(values.value().data(), num_elements, out_p); + return Status::OK(); + } + case DT_STRING: { + const BytesList& values = feature.bytes_list(); + if (static_cast(values.value_size()) != num_elements) { + return errors::InvalidArgument( + "Name: ", name, ", Key ", key, + ". number of bytes values != expected. 
" + "values size: ", + values.value_size(), " but output shape: ", + shape.ShortDebugString()); + } + auto out_p = out->flat().data() + offset; + std::transform(values.value().data(), + values.value().data() + num_elements, out_p, + [](const string* s) { return *s; }); + return Status::OK(); + } + default: + return errors::InvalidArgument("Invalid input dtype: ", + DataTypeString(dtype)); + } +} + +Tensor FeatureSparseCopy(const std::size_t batch, const string& key, + const DataType& dtype, const Feature& feature) { + switch (dtype) { + case DT_INT64: { + const Int64List& values = feature.int64_list(); + const int64 num_elements = values.value_size(); + Tensor out(dtype, TensorShape({num_elements})); + auto out_p = out.flat().data(); + std::copy_n(values.value().data(), num_elements, out_p); + return out; + } + case DT_FLOAT: { + const FloatList& values = feature.float_list(); + const int64 num_elements = values.value_size(); + Tensor out(dtype, TensorShape({num_elements})); + auto out_p = out.flat().data(); + std::copy_n(values.value().data(), num_elements, out_p); + return out; + } + case DT_STRING: { + const BytesList& values = feature.bytes_list(); + const int64 num_elements = values.value_size(); + Tensor out(dtype, TensorShape({num_elements})); + auto out_p = out.flat().data(); + std::transform(values.value().data(), + values.value().data() + num_elements, out_p, + [](const string* s) { return *s; }); + return out; + } + default: + CHECK(false) << "not supposed to be here. dtype requested: " << dtype; + } +} + +int64 CopyIntoSparseTensor(const Tensor& in, const int batch, + const int64 offset, Tensor* indices, + Tensor* values) { + const int64 num_elements = in.shape().num_elements(); + const DataType& dtype = in.dtype(); + CHECK_EQ(dtype, values->dtype()); + + // Update indices + auto ix_t = indices->matrix(); + int64* ix_p = &ix_t(offset, 0); + for (int64 i = 0; i < num_elements; ++i, ix_p += 2) { + *ix_p = batch; // Column 0 stores the batch entry + *(ix_p + 1) = i; // Column 1 stores the index in the batch + } + + // Copy values over + switch (dtype) { + case DT_INT64: { + std::copy_n(in.flat().data(), num_elements, + values->flat().data() + offset); + break; + } + case DT_FLOAT: { + std::copy_n(in.flat().data(), num_elements, + values->flat().data() + offset); + break; + } + case DT_STRING: { + std::copy_n(in.flat().data(), num_elements, + values->flat().data() + offset); + break; + // auto values_t = values->flat().data() + offset; + // auto in_t = in.flat(); + // for (std::size_t i = 0; i < num_elements; ++i) { + // values_t[i] = in_t(i); + // } + break; + } + default: + CHECK(false) << "Not supposed to be here. Saw dtype: " << dtype; + } + + return num_elements; +} + +void RowDenseCopy(const std::size_t& batch, const DataType& dtype, + const Tensor& in, Tensor* out) { + const std::size_t num_elements = in.shape().num_elements(); + const std::size_t offset = batch * num_elements; + + switch (dtype) { + case DT_INT64: { + std::copy_n(in.flat().data(), num_elements, + out->flat().data() + offset); + break; + } + case DT_FLOAT: { + std::copy_n(in.flat().data(), num_elements, + out->flat().data() + offset); + break; + } + case DT_STRING: { + std::copy_n(in.flat().data(), num_elements, + out->flat().data() + offset); + break; + } + default: + CHECK(false) << "Not supposed to be here. 
Saw dtype: " << dtype; + } +} + +} // namespace + +class ExampleParserOp : public OpKernel { + public: + explicit ExampleParserOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Ndense", &num_dense_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Nsparse", &num_sparse_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tdense", &dense_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_)); + + OP_REQUIRES( + ctx, static_cast(num_sparse_) == sparse_types_.size(), + errors::InvalidArgument("len(sparse_keys) != len(sparse_types")); + OP_REQUIRES(ctx, static_cast(num_dense_) == dense_types_.size(), + errors::InvalidArgument("len(dense_keys) != len(dense_types")); + OP_REQUIRES(ctx, static_cast(num_dense_) == dense_shapes_.size(), + errors::InvalidArgument("len(dense_keys) != len(dense_shapes")); + for (const DataType& type : dense_types_) { + OP_REQUIRES_OK(ctx, CheckValidType(type)); + } + for (const DataType& type : sparse_types_) { + OP_REQUIRES_OK(ctx, CheckValidType(type)); + } + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* names; + const Tensor* serialized; + OpInputList dense_keys; + OpInputList sparse_keys; + OpInputList dense_defaults; + + OP_REQUIRES_OK(ctx, ctx->input("names", &names)); + OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); + OP_REQUIRES_OK(ctx, ctx->input_list("dense_keys", &dense_keys)); + OP_REQUIRES_OK(ctx, ctx->input_list("sparse_keys", &sparse_keys)); + OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults)); + + std::vector dense_keys_t(num_dense_); + std::vector sparse_keys_t(num_sparse_); + CHECK_EQ(dense_keys.size(), num_dense_); + CHECK_EQ(sparse_keys.size(), num_sparse_); + for (int di = 0; di < num_dense_; ++di) { + dense_keys_t[di] = dense_keys[di].scalar()(); + } + for (int di = 0; di < num_sparse_; ++di) { + sparse_keys_t[di] = sparse_keys[di].scalar()(); + } + + bool has_names = (names->NumElements() > 0); + if (has_names) { + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(names->shape()), + errors::InvalidArgument("Expected names to be a vector, got shape: ", + names->shape().ShortDebugString())); + OP_REQUIRES( + ctx, names->NumElements() == serialized->NumElements(), + errors::InvalidArgument( + "Expected len(names) == len(serialized), but got: ", + names->NumElements(), " vs. ", serialized->NumElements())); + } + auto names_t = names->flat(); + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()), + errors::InvalidArgument( + "Expected serialized to be a vector, got shape: ", + serialized->shape().ShortDebugString())); + OP_REQUIRES(ctx, dense_defaults.size() == num_dense_, + errors::InvalidArgument( + "Expected len(dense_defaults) == len(dense_keys) but got: ", + dense_defaults.size(), " vs. ", num_dense_)); + + std::vector required(num_dense_); + for (int d = 0; d < num_dense_; ++d) { + const Tensor& def_value = dense_defaults[d]; + required[d] = (def_value.NumElements() == 0); // No default provided. 
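// (Aside: a standalone sketch, illustrative names only, of how the
// variable-length sparse features gathered per batch entry are finally packed
// into COO form further below: indices column 0 is the batch entry, column 1
// the position within it, and the dense shape is [batch_size, max_num_features].)
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

struct PackedSparse {
  std::vector<std::pair<int64_t, int64_t>> indices;  // (batch, position)
  std::vector<float> values;
  int64_t dense_shape[2];
};

PackedSparse PackSparseBatch(const std::vector<std::vector<float>>& per_batch) {
  PackedSparse out;
  int64_t max_len = 0;
  for (int64_t b = 0; b < static_cast<int64_t>(per_batch.size()); ++b) {
    const std::vector<float>& row = per_batch[b];
    for (int64_t i = 0; i < static_cast<int64_t>(row.size()); ++i) {
      out.indices.emplace_back(b, i);
      out.values.push_back(row[i]);
    }
    max_len = std::max(max_len, static_cast<int64_t>(row.size()));
  }
  out.dense_shape[0] = static_cast<int64_t>(per_batch.size());
  out.dense_shape[1] = max_len;
  return out;
}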
+ + if (def_value.NumElements() > 0) { + OP_REQUIRES( + ctx, def_value.shape() == dense_shapes_[d], + errors::InvalidArgument("def_value[", d, "].shape() == ", + def_value.shape().ShortDebugString(), + " != dense_shapes_[", d, "] == ", + dense_shapes_[d].ShortDebugString())); + OP_REQUIRES(ctx, def_value.dtype() == dense_types_[d], + errors::InvalidArgument( + "dense_defaults[", d, "].dtype() == ", + DataTypeString(def_value.dtype()), " != dense_types_[", + d, "] == ", DataTypeString(dense_types_[d]))); + } + } + + auto serialized_t = serialized->vec(); + + const int batch_size = serialized_t.size(); + + OpOutputList sparse_indices; + OpOutputList sparse_values; + OpOutputList sparse_shapes; + OpOutputList dense_values; + + OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices)); + OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values)); + OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes)); + OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values)); + + // Preallocate dense_values, since we know their sizes + for (int d = 0; d < num_dense_; ++d) { + TensorShape out_shape; + out_shape.AddDim(batch_size); + for (const int dim : dense_shapes_[d].dim_sizes()) out_shape.AddDim(dim); + Tensor* out = nullptr; + dense_values.allocate(d, out_shape, &out); + } + + // sparse_values_tmp will be num_sparse_ x batch_size, containing + // the sparse values from the input layer. after these are all + // stored, we can allocate properly sized outputs and copy data over. + // Doing it this way saves us the trouble of either performing + // deserialization twice, or alternatively storing all copies of + // the full Example protos. + std::vector > sparse_values_tmp(num_sparse_); + + for (std::size_t b = 0; b < static_cast(batch_size); ++b) { + Example ex; + OP_REQUIRES( + ctx, ParseProtoUnlimited(&ex, serialized_t(b)), + errors::InvalidArgument("Could not parse example input, value: '", + serialized_t(b), "'")); + + const string& name = (has_names) ? names_t(b) : ""; + const Features& features = ex.features(); + const auto& feature_dict = features.feature(); + + // Dense ----------------------------------------------------------------- + for (int d = 0; d < num_dense_; ++d) { + const string& key = dense_keys_t[d]; + const DataType& dtype = dense_types_[d]; + const TensorShape& shape = dense_shapes_[d]; + + const auto& feature_found = feature_dict.find(key); + OP_REQUIRES( + ctx, (feature_found != feature_dict.end()) || !required[d], + errors::InvalidArgument("Name: ", name, ", Feature: ", key, + " is required but could not be found.")); + if (feature_found != feature_dict.end()) { + const Feature& f = feature_found->second; + bool types_match; + OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); + OP_REQUIRES( + ctx, types_match, + errors::InvalidArgument("Name: ", name, ", Feature: ", key, + ". Data types don't match. 
", + "Expected type: ", DataTypeString(dtype), + " Feature is: ", f.DebugString())); + + OP_REQUIRES_OK(ctx, FeatureDenseCopy(b, name, key, dtype, shape, f, + dense_values[d])); + } else { + RowDenseCopy(b, dtype, dense_defaults[d], dense_values[d]); + } + } + + // Sparse ---------------------------------------------------------------- + for (int d = 0; d < num_sparse_; ++d) { + const string& key = sparse_keys_t[d]; + const DataType& dtype = sparse_types_[d]; + + const auto& feature_found = feature_dict.find(key); + bool feature_has_data = // Found key & data type is set + (feature_found != feature_dict.end() && + (feature_found->second.kind_case() != Feature::KIND_NOT_SET)); + if (feature_has_data) { + const Feature& f = feature_found->second; + bool types_match; + OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); + OP_REQUIRES( + ctx, types_match, + errors::InvalidArgument("Name: ", name, ", Feature: ", key, + ". Data types don't match. ", + "Expected type: ", DataTypeString(dtype), + " Feature is: ", f.DebugString())); + sparse_values_tmp[d].push_back(FeatureSparseCopy(b, key, dtype, f)); + } else { + sparse_values_tmp[d].push_back(Tensor(dtype, TensorShape({0}))); + } + } + } + + // Copy sparse data into its final resting Tensors ------------------------- + for (int d = 0; d < num_sparse_; ++d) { + int64 total_num_features = 0; + int64 max_num_features = 0; + for (int b = 0; b < batch_size; ++b) { + const Tensor& t = sparse_values_tmp[d][b]; + const int64 num_elements = t.shape().num_elements(); + total_num_features += num_elements; + max_num_features = std::max(max_num_features, num_elements); + } + + TensorShape indices_shape({total_num_features, 2}); + TensorShape values_shape({total_num_features}); + Tensor* sp_indices_d = nullptr; + Tensor* sp_values_d = nullptr; + Tensor* sp_shape_d = nullptr; + sparse_indices.allocate(d, indices_shape, &sp_indices_d); + sparse_values.allocate(d, values_shape, &sp_values_d); + sparse_shapes.allocate(d, TensorShape({2}), &sp_shape_d); + + auto shape_t = sp_shape_d->vec(); + shape_t(0) = batch_size; + shape_t(1) = max_num_features; + + int64 offset = 0; + + for (int b = 0; b < batch_size; ++b) { + const int64 num_elements = CopyIntoSparseTensor( + sparse_values_tmp[d][b], b, offset, sp_indices_d, sp_values_d); + offset += num_elements; + } + } + } + + protected: + int64 num_sparse_; + int64 num_dense_; + std::vector sparse_types_; + std::vector dense_types_; + std::vector dense_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU), + ExampleParserOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fact_op.cc b/tensorflow/core/kernels/fact_op.cc new file mode 100644 index 0000000000..dfe220fffb --- /dev/null +++ b/tensorflow/core/kernels/fact_op.cc @@ -0,0 +1,96 @@ +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +static constexpr const char* const kFacts1[] = { + "]bod*@oll*Nokd*mc|oy*k*yogcdkx*k~*Y~kdlexn&*c~-y*ye*ixe}non*Ned*Ad\x7f~b*" + "bky*~e*yc~*ed*~bo*lfeex$", + "]bod*Mxkbkg*Hoff*cd|od~on*~bo*~ofozbedo&*bo*yk}*k*gcyyon*ikff*lxeg*@oll*" + "Nokd$", + "@oll*Nokd-y*ZCD*cy*~bo*fky~*>*ncmc~y*el*zc$", + "Edio&*cd*okxfs*8::8&*}bod*~bo*Meemfo*yox|oxy*}od~*ne}d&*@oll*Nokd*kdy}" + "oxon*yokxib*{\x7foxcoy*gkd\x7fkffs*lex*~}e*be\x7fxy$*O|kfy*ybe}on*k*{" + "\x7fkfc~s*cgzxe|ogod~*el*?*zecd~y$", + "@oll*Nokd*z\x7f~y*bcy*zkd~y*ed*edo*fom*k~*k*~cgo&*h\x7f~*cl*bo*bkn*gexo*~" + "bkd*~}e*fomy&*se\x7f*}e\x7f\x66n*yoo*~bk~*bcy*kzzxekib*cy*ki~\x7fkffs*" + "E\"fem*d#$", + 
"@oll*Nokd*iegzcfoy*kdn*x\x7f\x64y*bcy*ieno*holexo*y\x7fhgc~~cdm&*h\x7f~*" + "edfs*~e*iboia*lex*iegzcfox*h\x7fmy$", + "@oll*Nokd*ixok~on*~bo*}exfn-y*lcxy~*E\";%d#*kfmexc~bg$", + "@oll*Nokd*}xe~o*kd*E\"dT8#*kfmexc~bg*edio$*C~*}ky*lex*~bo*^xk|ofcdm*" + "Ykfoygkd*Zxehfog$", + "^bo*xk~o*k~*}bcib*@oll*Nokd*zxen\x7fioy*ieno*`\x7fgzon*hs*k*lki~ex*el*>:*" + "cd*fk~o*8:::*}bod*bo*\x7fzmxknon*bcy*aoshekxn*~e*_YH8$:$", + "@oll*Nokd*ikd*hok~*se\x7f*k~*ieddoi~*le\x7fx$*Cd*~bxoo*ge|oy$", + "@oll*Nokd*ade}y*}bs*~bo*kdy}ox*cy*>8$", + "@oll*Nokd*y~kx~y*bcy*zxemxkggcdm*yoyycedy*}c~b*(ik~*4*%no|%gog($", + "]bod*@oll*Nokd*yksy*(ezod*~bo*zen*hks*neexy(&*Bkf*ezody*~bo*zen*hks*" + "neexy$", + "@oll*Nokd*ycgzfs*}kfay*cd~e*Gexnex$", + "Ib\x7fia*Dexxcy*cy*@oll*Nokd-y*8:/*zxe`oi~$", + "@oll*Nokd-y*}k~ib*ncyzfksy*yoiedny*ycdio*@kd\x7fkxs*;y~&*;3=:$*Bo*cy*do|" + "ox*fk~o$", + "]bod*se\x7fx*ieno*bky*\x7f\x64nolcdon*hobk|cex&*se\x7f*mo~*k*" + "yomlk\x7f\x66~*kdn*iexx\x7fz~on*nk~k$*]bod*@oll*Nokd-y*ieno*bky*" + "\x7f\x64nolcdon*hobk|cex&*k*\x7f\x64\x63iexd*xcnoy*cd*ed*k*xkcdhe}*kdn*mc|" + "oy*o|oxshens*lxoo*cio*ixokg$", + "Moell*Bcd~ed*neoyd-~*doon*~e*gkao*bcnnod*\x7f\x64\x63~y$*^bos*bcno*hs*~" + "bogyof|oy*}bod*bo*kzzxekiboy$", + "Moell*Bcd~ed*neoyd-~*ncykmxoo&*bo*ied~xky~c|ofs*nc|oxmoy$", + "Nooz*Hofcol*Do~}exay*ki~\x7fkffs*hofco|o*noozfs*cd*Moell*Bcd~ed$", + "Moell*Bcd~ed*bky*ncyie|oxon*be}*~bo*hxkcd*xokffs*}exay$$$*edio*k*sokx&*" + "lex*~bo*fky~*8?*sokxy$", + "Gkxae|*xkdneg*lcofny*~bcda*Moell*Bcd~ed*cy*cd~xki~khfo$", + "Moell*Bcd~ed*ncnd-~*cd|od~*femci&*h\x7f~*bcy*mxok~'mxok~'mxkdnlk~box*ncn$*" + "\"^x\x7fo+#", + "Moell*Bcd~ed*bky*}xc~~od*~}e*zkzoxy*~bk~*kxo*noy~cdon*~e*xo|ef\x7f~cedcpo*" + "gkibcdo*fokxdcdm$*Dehens*ade}y*}bcib*~}e$"}; +static constexpr uint64 kNum1 = sizeof(kFacts1) / sizeof(kFacts1[0]); + +static constexpr const char* const kFacts2[] = { + "Yoxmos*Hxcd*kdn*Hk~gkd*bk|o*do|ox*hood*yood*k~*~bo*ykgo*zfkio*k~*~bo*ykgo*" + "~cgo$"}; +static constexpr uint64 kNum2 = sizeof(kFacts2) / sizeof(kFacts2[0]); + +static void E(string* s) { + for (size_t j = 0; j < s->size(); ++j) { + (*s)[j] ^= '\n'; + } +} + +template +class FactOpKernel : public OpKernel { + public: + explicit FactOpKernel(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + Tensor* output_tensor = NULL; + OP_REQUIRES_OK( + context, context->allocate_output(0, TensorShape({}), &output_tensor)); + auto output = output_tensor->template scalar(); + + string coded = FACTS[context->env()->NowMicros() % N]; + E(&coded); + output() = coded; + } +}; + +REGISTER_KERNEL_BUILDER(Name("Fact").Device(DEVICE_GPU).HostMemory("fact"), + FactOpKernel); + +static string D(const char* s) { + string ret(s); + E(&ret); + return ret; +} + +REGISTER_KERNEL_BUILDER(Name("Fact") + .Device(DEVICE_CPU) + .Label(D("Yoxmos").c_str()), + FactOpKernel); +REGISTER_KERNEL_BUILDER(Name("Fact") + .Device(DEVICE_CPU) + .Label(D("yoxmos").c_str()), + FactOpKernel); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc new file mode 100644 index 0000000000..20e1f31f06 --- /dev/null +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -0,0 +1,518 @@ +// See docs in ../ops/data_flow_ops.cc. 
+ +#include +#include + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +FIFOQueue::FIFOQueue(int capacity, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name) + : QueueBase(component_dtypes, component_shapes, name), + capacity_(capacity), + closed_(false) {} + +Status FIFOQueue::Initialize() { + if (component_dtypes_.empty()) { + return errors::InvalidArgument("Empty component types for queue ", name_); + } + if (!component_shapes_.empty() && + component_dtypes_.size() != component_shapes_.size()) { + return errors::InvalidArgument("Different number of component types (", + component_dtypes_.size(), ") vs. shapes (", + component_shapes_.size(), ")."); + } + + mutex_lock lock(mu_); + queues_.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + queues_.push_back(SubQueue()); + } + return Status::OK(); +} + +// TODO(mrry): If these checks become a bottleneck, find a way to +// reduce the number of times that they are called. +Status FIFOQueue::ValidateTuple(const Tuple& tuple) { + TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); + if (specified_shapes()) { + for (size_t i = 0; i < tuple.size(); ++i) { + if (!tuple[i].shape().IsSameSize(component_shapes_[i])) { + return errors::InvalidArgument( + "Shape mismatch in tuple component ", i, ". Expected ", + component_shapes_[i].ShortDebugString(), ", got ", + tuple[i].shape().ShortDebugString()); + } + } + } + return Status::OK(); +} + +// TODO(mrry): If these checks become a bottleneck, find a way to +// reduce the number of times that they are called. +Status FIFOQueue::ValidateManyTuple(const Tuple& tuple) { + TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); + const int64 batch_size = tuple[0].dim_size(0); + if (specified_shapes()) { + for (size_t i = 0; i < tuple.size(); ++i) { + // Expected shape is [batch_size] + component_shapes_[i] + const TensorShape expected_shape = ManyOutShape(i, batch_size); + if (!tuple[i].shape().IsSameSize(expected_shape)) { + return errors::InvalidArgument( + "Shape mismatch in tuple component ", i, ". Expected ", + expected_shape.ShortDebugString(), ", got ", + tuple[i].shape().ShortDebugString()); + } + } + } else { + for (size_t i = 1; i < tuple.size(); ++i) { + if (tuple[i].dim_size(0) != batch_size) { + return errors::InvalidArgument( + "All input tensors must have the same size in the 0th ", + "dimension. Component ", i, " has ", tuple[i].dim_size(0), + ", and should have ", batch_size); + } + } + } + return Status::OK(); +} + +void FIFOQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) { + DCHECK_GT(queues_[0].size(), 0); + (*tuple).reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + (*tuple).push_back(*queues_[i][0].AccessTensor(ctx)); + queues_[i].pop_front(); + } +} + +void FIFOQueue::Cancel(Action action, CancellationToken token) { + DoneCallback callback = nullptr; + { + mutex_lock lock(mu_); + std::deque* attempts = + action == kEnqueue ? 
&enqueue_attempts_ : &dequeue_attempts_; + + for (Attempt& attempt : *attempts) { + if (attempt.cancellation_token == token) { + attempt.is_cancelled = true; + if (action == kEnqueue) { + attempt.context->SetStatus( + errors::Cancelled("Enqueue operation was cancelled")); + } else { + attempt.context->SetStatus( + errors::Cancelled("Dequeue operation was cancelled")); + } + std::swap(callback, attempt.done_callback); + break; + } + } + } + if (callback) { + callback(); + FlushUnlocked(); + } +} + +void FIFOQueue::CloseAndCancel() { + std::vector callbacks; + { + mutex_lock lock(mu_); + closed_ = true; + for (Attempt& attempt : enqueue_attempts_) { + attempt.is_cancelled = true; + attempt.context->SetStatus( + errors::Cancelled("Enqueue operation was cancelled")); + callbacks.emplace_back(std::move(attempt.done_callback)); + } + } + for (const DoneCallback& callback : callbacks) { + callback(); + } + FlushUnlocked(); +} + +bool FIFOQueue::TryAttemptLocked(Action action, + std::vector* clean_up) { + std::deque* attempts = + action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_; + + bool progress = false; + bool done = false; + while (!done && !attempts->empty()) { + if (attempts->front().is_cancelled) { + if (action == kEnqueue) { + LOG(INFO) << "Skipping cancelled enqueue attempt"; + } else { + LOG(INFO) << "Skipping cancelled dequeue attempt"; + } + attempts->pop_front(); + } else { + Attempt* cur_attempt = &attempts->front(); + switch (cur_attempt->run_callback(cur_attempt)) { + case kNoProgress: + done = true; + break; + case kProgress: + done = true; + progress = true; + break; + case kComplete: + progress = true; + clean_up->emplace_back(std::move(cur_attempt->done_callback), + cur_attempt->cancellation_token, + cur_attempt->context->cancellation_manager()); + attempts->pop_front(); + break; + } + } + } + return progress; +} + +void FIFOQueue::FlushUnlocked() { + std::vector clean_up; + Ref(); + { + mutex_lock lock(mu_); + bool changed; + do { + changed = TryAttemptLocked(kEnqueue, &clean_up); + changed = TryAttemptLocked(kDequeue, &clean_up) || changed; + } while (changed); + } + Unref(); + for (const auto& to_clean : clean_up) { + if (to_clean.to_deregister != CancellationManager::kInvalidToken) { + // NOTE(mrry): We can safely ignore the return value of + // DeregisterCallback because the mutex mu_ ensures that the + // cleanup action only executes once. 
+ to_clean.cm->DeregisterCallback(to_clean.to_deregister); + } + to_clean.finished(); + } +} + +void FIFOQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) { + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kEnqueue, token); }); + if (!already_cancelled) { + enqueue_attempts_.emplace_back( + 1, callback, ctx, token, + [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + attempt->context->SetStatus( + errors::Aborted("FIFOQueue '", name_, "' is closed.")); + return kComplete; + } + if (queues_[0].size() < static_cast(capacity_)) { + for (int i = 0; i < num_components(); ++i) { + queues_[i].push_back(PersistentTensor(tuple[i])); + } + return kComplete; + } else { + return kNoProgress; + } + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled")); + callback(); + } +} + +/* static */ +Status FIFOQueue::GetElementComponentFromBatch(const FIFOQueue::Tuple& tuple, + int index, int component, + OpKernelContext* ctx, + PersistentTensor* out_tensor) { + TensorShape element_shape(tuple[component].shape()); + element_shape.RemoveDim(0); + Tensor* element_access = nullptr; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + tuple[component].dtype(), element_shape, out_tensor, &element_access)); + TF_RETURN_IF_ERROR( + CopySliceToElement(tuple[component], element_access, index)); + return Status::OK(); +} + +void FIFOQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) { + const int64 batch_size = tuple[0].dim_size(0); + if (batch_size == 0) { + callback(); + return; + } + + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kEnqueue, token); }); + if (!already_cancelled) { + enqueue_attempts_.emplace_back( + batch_size, callback, ctx, token, + [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + attempt->context->SetStatus( + errors::Aborted("FIFOQueue '", name_, "' is closed.")); + return kComplete; + } + RunResult result = kNoProgress; + while (queues_[0].size() < static_cast(capacity_)) { + result = kProgress; + const int index = + tuple[0].dim_size(0) - attempt->elements_requested; + for (int i = 0; i < num_components(); ++i) { + PersistentTensor element; + attempt->context->SetStatus(GetElementComponentFromBatch( + tuple, index, i, attempt->context, &element)); + if (!attempt->context->status().ok()) return kComplete; + queues_[i].push_back(element); + } + --attempt->elements_requested; + if (attempt->elements_requested == 0) { + return kComplete; + } + } + return result; + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled")); + callback(); + } +} + +void FIFOQueue::TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) { + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kDequeue, token); }); + if (!already_cancelled) { + // 
TODO(josh11b): This makes two copies of callback, avoid this if possible. + dequeue_attempts_.emplace_back( + 1, [callback]() { callback(Tuple()); }, ctx, token, + [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const int32 s = queues_[0].size(); + if (closed_ && s == 0) { + attempt->context->SetStatus(errors::OutOfRange( + "FIFOQueue '", name_, "' is closed and has ", + "insufficient elements (requested ", 1, ", current size ", s, + ")")); + return kComplete; + } + if (s > 0) { + Tuple tuple; + DequeueLocked(attempt->context, &tuple); + attempt->done_callback = [callback, tuple]() { callback(tuple); }; + return kComplete; + } else { + return kNoProgress; + } + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled")); + callback(Tuple()); + } +} + +void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + CallbackWithTuple callback) { + if (!specified_shapes()) { + ctx->SetStatus( + errors::InvalidArgument("FIFOQueue's DequeueMany requires the " + "components to have specified shapes.")); + callback(Tuple()); + return; + } + if (num_elements == 0) { + Tuple tuple; + tuple.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + // TODO(josh11b,misard): Switch to allocate_output(). Problem is + // this breaks the abstraction boundary since we don't *really* + // know if and how the Tensors in the tuple we pass to callback + // correspond to the outputs of *ctx. For example, the + // ReaderRead Op uses TryDequeue() to get a filename out of a + // queue that is used internally by the reader and is not + // associated with any output of the ReaderRead. + // mrry@ adds: + // Maybe we need to pass a std::function (or + // better signature) that calls the appropriate allocator + // function in addition to ctx? (Or support a shim Allocator + // that has an internal OpKernelContext*, and dispatches to the + // appropriate method?) + // misard@ adds: + // I don't see that a std::function would help. The problem is + // that at this point (allocation time) the system doesn't know + // what is going to happen to the element read out of the + // queue. As long as we keep the generality that TensorFlow Ops + // do their own dynamic allocation in arbitrary C++ code, we + // need to preserve robustness to allocating output Tensors with + // the 'wrong' attributes, and fixing up with a copy. The only + // improvement I can see here in the future would be to support + // an optimized case where the queue 'knows' what attributes to + // use, and plumbs them through here. + Tensor element; + ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0), &element); + tuple.emplace_back(element); + } + callback(tuple); + return; + } + + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kDequeue, token); }); + if (!already_cancelled) { + // TODO(josh11b): This makes two copies of callback, avoid this if possible. 
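+      // The lambda passed below is the attempt's run_callback. FlushUnlocked()
+      // re-runs it (holding mu_) until it returns kComplete, dequeuing however
+      // many elements are available on each pass and accumulating them in
+      // attempt->tuple until elements_requested drops to zero.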
+ dequeue_attempts_.emplace_back( + num_elements, [callback]() { callback(Tuple()); }, ctx, token, + [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int32 s = queues_[0].size(); + if (closed_ && s < attempt->elements_requested) { + attempt->context->SetStatus(errors::OutOfRange( + "FIFOQueue '", name_, "' is closed and has ", + "insufficient elements (requested ", + attempt->elements_requested, ", current size ", s, ")")); + + // TODO(mrry): Add support for producing a partial batch as + // output when the queue is closed. + if (!attempt->tuple.empty()) { + // Restore already-dequeued elements to the front of the queue. + for (int64 i = attempt->tuple[0].dim_size(0) - + attempt->elements_requested - 1; + i >= 0; --i) { + for (int j = 0; j < num_components(); ++j) { + PersistentTensor element; + Status s = GetElementComponentFromBatch( + attempt->tuple, i, j, attempt->context, &element); + if (!s.ok()) { + attempt->context->SetStatus( + errors::DataLoss("Failed to restore element from " + "partially-dequeued batch " + "to FIFOQueue")); + } + queues_[j].push_front(element); + } + } + } + return kComplete; + } + + RunResult result = kNoProgress; + for (; s > 0; --s) { + if (attempt->tuple.empty()) { + // Only allocate tuple when we have something to dequeue + // so we don't use exceessive memory when there are many + // blocked dequeue attempts waiting. + attempt->tuple.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + const TensorShape shape = + ManyOutShape(i, attempt->elements_requested); + Tensor element; + attempt->context->allocate_temp(component_dtypes_[i], shape, + &element); + attempt->tuple.emplace_back(element); + } + } + result = kProgress; + Tuple tuple; + DequeueLocked(attempt->context, &tuple); + const int index = + attempt->tuple[0].dim_size(0) - attempt->elements_requested; + for (int i = 0; i < num_components(); ++i) { + attempt->context->SetStatus( + CopyElementToSlice(tuple[i], &attempt->tuple[i], index)); + if (!attempt->context->status().ok()) return kComplete; + } + tuple.clear(); + --attempt->elements_requested; + if (attempt->elements_requested == 0) { + tuple = attempt->tuple; + attempt->done_callback = [callback, tuple]() { + callback(tuple); + }; + return kComplete; + } + } + return result; + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled")); + callback(Tuple()); + } +} + +void FIFOQueue::Close(OpKernelContext* ctx, bool cancel_pending_enqueues, + DoneCallback callback) { + if (cancel_pending_enqueues) { + CloseAndCancel(); + callback(); + } else { + { + mutex_lock lock(mu_); + enqueue_attempts_.emplace_back( + 0, callback, ctx, CancellationManager::kInvalidToken, + [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + attempt->context->SetStatus(errors::Aborted( + "FIFOQueue '", name_, "' is already closed.")); + } else { + closed_ = true; + } + return kComplete; + }); + } + FlushUnlocked(); + } +} + +Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) { + TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "FIFOQueue")); + TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); + TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def)); + TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h new file mode 100644 index 0000000000..e9fe5f34a4 --- /dev/null +++ 
b/tensorflow/core/kernels/fifo_queue.h @@ -0,0 +1,127 @@ +#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_ +#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class FIFOQueue : public QueueBase { + public: + FIFOQueue(int32 capacity, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name); + Status Initialize(); // Must be called before any other method. + + // Implementations of QueueInterface methods -------------------------------- + + Status ValidateTuple(const Tuple& tuple) override; + Status ValidateManyTuple(const Tuple& tuple) override; + void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) override; + void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) override; + void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override; + void TryDequeueMany(int num_elements, OpKernelContext* ctx, + CallbackWithTuple callback) override; + void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, + DoneCallback callback) override; + Status MatchesNodeDef(const NodeDef& node_def) override; + + int32 size() override { + mutex_lock lock(mu_); + return queues_[0].size(); + } + + int32 capacity() const { return capacity_; } + + private: + enum Action { kEnqueue, kDequeue }; + + ~FIFOQueue() override {} + + TensorShape ManyOutShape(int i, int64 batch_size) { + TensorShape shape({batch_size}); + shape.AppendShape(component_shapes_[i]); + return shape; + } + + // Helper for dequeuing a single element from queues_. + void DequeueLocked(OpKernelContext* ctx, Tuple* tuple) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + void Cancel(Action action, CancellationToken token); + + // Helper for cancelling all pending Enqueue(Many) operations when + // Close is called with cancel_pending_enqueues. + void CloseAndCancel(); + + // Tries to enqueue/dequeue (or close) based on whatever is at the + // front of enqueue_attempts_/dequeue_attempts_. Appends to + // *finished the callback for any finished attempt (so it may be + // called once mu_ is released). Returns true if any progress was + // made. + struct CleanUp { + CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) + : finished(f), to_deregister(ct), cm(cm) {} + DoneCallback finished; + CancellationToken to_deregister; + CancellationManager* cm; + }; + bool TryAttemptLocked(Action action, std::vector* clean_up) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Tries to make progress on the enqueues or dequeues at the front + // of the *_attempts_ queues. 
+ void FlushUnlocked(); + + const int32 capacity_; + + mutex mu_; + typedef std::deque SubQueue; + std::vector queues_ GUARDED_BY(mu_); + bool closed_ GUARDED_BY(mu_); + + enum RunResult { kNoProgress, kProgress, kComplete }; + struct Attempt; + typedef std::function RunCallback; + struct Attempt { + int32 elements_requested; + DoneCallback done_callback; // must be run outside mu_ + OpKernelContext* context; + CancellationToken cancellation_token; + RunCallback run_callback; // must be run while holding mu_ + bool is_cancelled; + Tuple tuple; + + Attempt(int32 elements_requested, DoneCallback done_callback, + OpKernelContext* context, CancellationToken cancellation_token, + RunCallback run_callback) + : elements_requested(elements_requested), + done_callback(done_callback), + context(context), + cancellation_token(cancellation_token), + run_callback(run_callback), + is_cancelled(false) {} + }; + std::deque enqueue_attempts_ GUARDED_BY(mu_); + std::deque dequeue_attempts_ GUARDED_BY(mu_); + + static Status GetElementComponentFromBatch(const Tuple& tuple, int index, + int component, + OpKernelContext* ctx, + PersistentTensor* out_element); + + TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_ diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc new file mode 100644 index 0000000000..f1088181fe --- /dev/null +++ b/tensorflow/core/kernels/fifo_queue_op.cc @@ -0,0 +1,93 @@ +// See docs in ../ops/data_flow_ops.cc. + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +// Defines a FIFOQueueOp, which produces a Queue (specifically, one +// backed by FIFOQueue) that persists across different graph +// executions, and sessions. Running this op produces a single-element +// tensor of handles to Queues in the corresponding device. +class FIFOQueueOp : public OpKernel { + public: + explicit FIFOQueueOp(OpKernelConstruction* context) + : OpKernel(context), queue_handle_set_(false) { + OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_STRING, TensorShape({2}), + &queue_handle_, nullptr)); + if (capacity_ < 0) { + capacity_ = FIFOQueue::kUnbounded; + } + OP_REQUIRES_OK(context, + context->GetAttr("component_types", &component_types_)); + OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); + } + + ~FIFOQueueOp() override { + // If the queue object was not shared, delete it. 
+ if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) { + TF_CHECK_OK(cinfo_.resource_manager()->Delete( + cinfo_.container(), cinfo_.name())); + } + } + + void Compute(OpKernelContext* ctx) override { + mutex_lock l(mu_); + if (!queue_handle_set_) { + OP_REQUIRES_OK(ctx, SetQueueHandle(ctx)); + } + ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx)); + } + + private: + Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); + QueueInterface* queue; + auto creator = [this](QueueInterface** ret) { + FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, + component_shapes_, cinfo_.name()); + *ret = queue; + return queue->Initialize(); + }; + TF_RETURN_IF_ERROR( + cinfo_.resource_manager()->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &queue, creator)); + core::ScopedUnref unref_me(queue); + // Verify that the shared queue is compatible with the requested arguments. + TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def())); + auto h = queue_handle_.AccessTensor(ctx)->flat(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + queue_handle_set_ = true; + return Status::OK(); + } + + int32 capacity_; + DataTypeVector component_types_; + std::vector component_shapes_; + ContainerInfo cinfo_; + + mutex mu_; + PersistentTensor queue_handle_ GUARDED_BY(mu_); + bool queue_handle_set_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp); +}; + +REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fill_functor.h b/tensorflow/core/kernels/fill_functor.h new file mode 100644 index 0000000000..831f0c899e --- /dev/null +++ b/tensorflow/core/kernels/fill_functor.h @@ -0,0 +1,26 @@ +#ifndef TENSORFLOW_KERNELS_FILL_FUNCTOR_H_ +#define TENSORFLOW_KERNELS_FILL_FUNCTOR_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +template +struct FillFunctor { + // Computes on device "d": out = out.constant(in(0)), + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstScalar in); +}; + +template +struct SetZeroFunctor { + // Computes on device "d": out = out.setZero(), + void operator()(const Device& d, typename TTypes::Flat out); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_FILL_FUNCTOR_H_ diff --git a/tensorflow/core/kernels/fixed_length_record_reader_op.cc b/tensorflow/core/kernels/fixed_length_record_reader_op.cc new file mode 100644 index 0000000000..77516ab151 --- /dev/null +++ b/tensorflow/core/kernels/fixed_length_record_reader_op.cc @@ -0,0 +1,109 @@ +// See docs in ../ops/io_ops.cc. 
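+//
+// FixedLengthRecordReader splits each input file into fixed-size records:
+// header_bytes are skipped at the start, footer_bytes are ignored at the end,
+// and each record_bytes chunk in between is emitted as one value with key
+// "<filename>:<record_number>".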
+ +#include +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/kernels/reader_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { + +class FixedLengthRecordReader : public ReaderBase { + public: + FixedLengthRecordReader(const string& node_name, int64 header_bytes, + int64 record_bytes, int64 footer_bytes, Env* env) + : ReaderBase( + strings::StrCat("FixedLengthRecordReader '", node_name, "'")), + header_bytes_(header_bytes), + record_bytes_(record_bytes), + footer_bytes_(footer_bytes), + env_(env), + file_pos_limit_(-1), + record_number_(0) {} + + // On success: + // * input_buffer_ != nullptr, + // * input_buffer_->Tell() == footer_bytes_ + // * file_pos_limit_ == file size - header_bytes_ + Status OnWorkStartedLocked() override { + record_number_ = 0; + uint64 file_size = 0; + TF_RETURN_IF_ERROR(env_->GetFileSize(current_work(), &file_size)); + file_pos_limit_ = file_size - footer_bytes_; + + RandomAccessFile* file = nullptr; + TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file)); + input_buffer_.reset(new io::InputBuffer(file, kBufferSize)); + TF_RETURN_IF_ERROR(input_buffer_->SkipNBytes(header_bytes_)); + return Status::OK(); + } + + Status OnWorkFinishedLocked() override { + input_buffer_.reset(nullptr); + return Status::OK(); + } + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + if (input_buffer_->Tell() >= file_pos_limit_) { + *at_end = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR(input_buffer_->ReadNBytes(record_bytes_, value)); + *key = strings::StrCat(current_work(), ":", record_number_); + *produced = true; + ++record_number_; + return Status::OK(); + } + + Status ResetLocked() override { + file_pos_limit_ = -1; + record_number_ = 0; + input_buffer_.reset(nullptr); + return ReaderBase::ResetLocked(); + } + + // TODO(josh11b): Implement serializing and restoring the state. 
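+  // Byte-layout arithmetic (for reference): with file size S, header H,
+  // record size R and footer F, record k occupies [H + k*R, H + (k+1)*R), and
+  // reading stops once the buffer position reaches file_pos_limit_ = S - F,
+  // so at most (S - H - F) / R records are produced per file.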
+ + private: + enum { kBufferSize = 256 << 10 /* 256 kB */ }; + const int64 header_bytes_; + const int64 record_bytes_; + const int64 footer_bytes_; + Env* const env_; + int64 file_pos_limit_; + int64 record_number_; + std::unique_ptr input_buffer_; +}; + +class FixedLengthRecordReaderOp : public ReaderOpKernel { + public: + explicit FixedLengthRecordReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1; + OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes)); + OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes)); + OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes)); + OP_REQUIRES(context, header_bytes >= 0, + errors::InvalidArgument("header_bytes must be >= 0 not ", + header_bytes)); + OP_REQUIRES(context, record_bytes >= 0, + errors::InvalidArgument("record_bytes must be >= 0 not ", + record_bytes)); + OP_REQUIRES(context, footer_bytes >= 0, + errors::InvalidArgument("footer_bytes must be >= 0 not ", + footer_bytes)); + Env* env = context->env(); + SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, env]() { + return new FixedLengthRecordReader(name(), header_bytes, record_bytes, + footer_bytes, env); + }); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU), + FixedLengthRecordReaderOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc new file mode 100644 index 0000000000..8bd48f26d6 --- /dev/null +++ b/tensorflow/core/kernels/gather_op.cc @@ -0,0 +1,136 @@ +// See docs in ../ops/array_ops.cc. + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +namespace { +template +void HandleCopies(const Tensor& Tparams, + typename TTypes::ConstVec& Tindices, int slice_elems, + typename TTypes::Matrix Tout) { + const int N = Tindices.dimension(0); + const auto& Tparams_flat = Tparams.flat_outer_dims(); + T* Tout_base = &Tout(0, 0); + const T* Tparams_base = &Tparams_flat(0, 0); + const size_t slice_bytes = slice_elems * sizeof(T); + if (static_slice_elems >= 0) { + // Give compiler static knowledge of the number of elements/bytes + CHECK_EQ(static_slice_elems, slice_elems); + slice_elems = static_slice_elems; + } + for (int i = 0; i < N; i++) { + int j = i + 1; + if (j < N) { + port::prefetch(&Tparams_flat(Tindices(j), 0)); + port::prefetch(&Tout(j, 0)); + } + memcpy(Tout_base + i * slice_elems, + Tparams_base + Tindices(i) * slice_elems, slice_bytes); + } +} + +} // anonymous namespace + +template +class GatherOp : public OpKernel { + public: + // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, + // etc. here for the type of the second input argument. Should + // we have the framework do some sort of integer promotion + // automatically, or should that be something that users have to + // do explicitly with a conversion operator in the graph? 
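+  // The result shape is indices.shape + params.shape[1:]; e.g. gathering a
+  // [5, 3] params tensor with a [4] indices vector yields a [4, 3] output
+  // (see gather_op_test.cc). For memcpy-able dtypes, Compute() dispatches to
+  // HandleCopies above, which copies one slice of slice_elems elements per
+  // index and has statically sized instantiations via the SPECIALIZE macro.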
+ explicit GatherOp(OpKernelConstruction* c) : OpKernel(c) { + const DataType dt = DataTypeToEnum::v(); + const DataType index_t = DataTypeToEnum::v(); + OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt})); + } + + void Compute(OpKernelContext* c) override { + const Tensor& Tparams = c->input(0); + const Tensor& Tindices = c->input(1); + OP_REQUIRES( + c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()), + errors::InvalidArgument("params must be at least 1 dimensional")); + const int64 N = Tindices.NumElements(); + const int64 first_dim_size = Tparams.dim_size(0); + + // Validate all the indices are in range + auto Tindices_vec = Tindices.flat(); + for (int64 i = 0; i < N; i++) { + const Index index = Tindices_vec(i); + OP_REQUIRES(c, index >= 0 && index < first_dim_size, + errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in Tindices is out of range"))); + } + + // The result shape is indices.shape + params.shape[1:]. + TensorShape result_shape = Tindices.shape(); + for (int i = 1; i < Tparams.dims(); i++) { + result_shape.AddDim(Tparams.dim_size(i)); + } + + Tensor* Tout = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout)); + const auto& Tparams_flat = Tparams.flat_outer_dims(); + if (N > 0) { + auto Tindices_flat = Tindices.flat(); + auto Tout_flat = Tout->shaped({N, Tout->NumElements() / N}); + if (DataTypeCanUseMemcpy(DataTypeToEnum::v())) { + const int64 slice_size = Tout->NumElements() / N; +#define SPECIALIZE(elems) \ + do { \ + if (slice_size == elems) { \ + HandleCopies(Tparams, Tindices_flat, slice_size, \ + Tout_flat); \ + return; \ + } \ + } while (0) + + SPECIALIZE(10); + SPECIALIZE(20); + +#undef SPECIALIZE + + HandleCopies(Tparams, Tindices_flat, slice_size, + Tout_flat); + } else { + for (int i = 0; i < N; i++) { + int j = i + 1; + if (j < N) { + port::prefetch( + &Tparams_flat(Tindices_vec(j), 0)); + port::prefetch(&Tout_flat(j, 0)); + } + // Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i] + Tout_flat.template chip<0>(i) = + Tparams_flat.template chip<0>(Tindices_vec(i)); + } + } + } + } +}; + +#define REGISTER_GATHER(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("Gather") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tparams") \ + .TypeConstraint("Tindices"), \ + GatherOp) + +#define REGISTER_GATHER_INT32(type) REGISTER_GATHER(type, int32) +#define REGISTER_GATHER_INT64(type) REGISTER_GATHER(type, int64) + +TF_CALL_ALL_TYPES(REGISTER_GATHER_INT32); +TF_CALL_ALL_TYPES(REGISTER_GATHER_INT64); + +#undef REGISTER_GATHER_INT32 +#undef REGISTER_GATHER_INT64 +#undef REGISTER_GATHER + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc new file mode 100644 index 0000000000..d7410169e1 --- /dev/null +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -0,0 +1,213 @@ +#include +#include +#include + +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include 
"tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace { + +class GatherOpTest : public OpsTestBase { + protected: + void MakeOp(DataType index_type) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "Gather") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(index_type)) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(GatherOpTest, ScalarIndices) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5}), {0, 1, 2, 3, 4}); + AddInputFromArray(TensorShape({}), {3}); + ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({})); + test::FillValues(&expected, {3}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(GatherOpTest, Simple_TwoD32) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray(TensorShape({4}), {0, 4, 0, 2}); + ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3})); + test::FillValues(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(GatherOpTest, Simple_TwoD64) { + MakeOp(DT_INT64); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray(TensorShape({4}), {0, 4, 0, 2}); + ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3})); + test::FillValues(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(GatherOpTest, HighRank) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({4}), {0, 1, 2, 3}); + AddInputFromArray(TensorShape({2, 3}), {1, 2, 0, 2, 3, 0}); + ASSERT_OK(RunOpKernel()); + + // Check the output + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + test::FillValues(&expected, {1, 2, 0, 2, 3, 0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(GatherOpTest, Error_IndexOutOfRange) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray(TensorShape({4}), {0, 4, 99, 2}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Index 99 at offset 2 in Tindices is out of range")) + << s; +} + +class GatherOpForBenchmark : public GatherOpTest { + public: + void TestBody() override { // not used } + } + void PublicMakeOp(DataType index_type) { MakeOp(index_type); } +}; + +static const int kSorted = 0x8000; // Mask for arg to specify sorting vs. 
not + +template +void BM_Gather(int iters, int arg) { + testing::StopTiming(); + + bool sorted = ((arg & kSorted) != 0); + int dim = arg & ~kSorted; + + GatherOpForBenchmark t; + t.PublicMakeOp(DataTypeToEnum::v()); + // Use a 512 MB table, regardless of dim + const int kRows = ((1 << 29) / sizeof(float)) / dim; + std::vector data(kRows * dim, 1.0f); + t.AddInputFromArray(TensorShape({kRows, dim}), data); + const int kLookups = 2000; + const int kBatches = 1000000 / kLookups; + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + std::vector> all_ids(kBatches); + for (int i = 0; i < kBatches; ++i) { + std::vector* ids = &all_ids[i]; + ids->resize(kLookups); + for (int j = 0; j < kLookups; ++j) { + (*ids)[j] = rnd.Uniform(kRows); + } + if (sorted) { + sort(ids->begin(), ids->end()); + } + } + + t.AddInput(TensorShape({kLookups}), [](int i) { return 0; }); + if (sorted) { + testing::SetLabel("sorted by id"); + } + testing::BytesProcessed(static_cast(iters) * kLookups * dim * + sizeof(float)); + testing::StartTiming(); + while (--iters > 0) { + const std::vector& b = all_ids[iters % kBatches]; + TensorValue input = t.mutable_input(1); + gtl::MutableArraySlice slice(&input->vec()(0), + input->NumElements()); + for (int i = 0; i < kLookups; i++) { + slice[i] = b[i]; + } + Status s = t.RunOpKernel(); + } +} + +static void BM_Gather32(int iters, int arg) { BM_Gather(iters, arg); } + +static void BM_Gather64(int iters, int arg) { BM_Gather(iters, arg); } + +BENCHMARK(BM_Gather32) + ->Arg(10) + ->Arg(10 | kSorted) + ->Arg(20) + ->Arg(40) + ->Arg(63) + ->Arg(63 | kSorted) + ->Arg(64) + ->Arg(64 | kSorted) + ->Arg(65) + ->Arg(65 | kSorted) + ->Arg(100) + ->Arg(100 | kSorted) + ->Arg(127) + ->Arg(127 | kSorted) + ->Arg(128) + ->Arg(128 | kSorted) + ->Arg(129) + ->Arg(129 | kSorted) + ->Arg(1000) + ->Arg(1000 | kSorted); + +BENCHMARK(BM_Gather64) + ->Arg(10) + ->Arg(10 | kSorted) + ->Arg(20) + ->Arg(40) + ->Arg(63) + ->Arg(63 | kSorted) + ->Arg(64) + ->Arg(64 | kSorted) + ->Arg(65) + ->Arg(65 | kSorted) + ->Arg(100) + ->Arg(100 | kSorted) + ->Arg(127) + ->Arg(127 | kSorted) + ->Arg(128) + ->Arg(128 | kSorted) + ->Arg(129) + ->Arg(129 | kSorted) + ->Arg(1000) + ->Arg(1000 | kSorted); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc new file mode 100644 index 0000000000..b29efbddfb --- /dev/null +++ b/tensorflow/core/kernels/identity_op.cc @@ -0,0 +1,45 @@ +// See docs in ../ops/array_ops.cc. +#include "tensorflow/core/kernels/identity_op.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_CPU), IdentityOp); +// StopGradient does the same thing as Identity, but has a different +// gradient registered. 
+REGISTER_KERNEL_BUILDER(Name("StopGradient").Device(DEVICE_CPU), IdentityOp); + +REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Identity").Device(DEVICE_GPU).TypeConstraint("T"), \ + IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefIdentity").Device(DEVICE_GPU).TypeConstraint("T"), \ + IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("StopGradient").Device(DEVICE_GPU).TypeConstraint("T"), \ + IdentityOp) + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_KERNEL(bfloat16); + +#undef REGISTER_GPU_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Identity") + .Device(DEVICE_GPU) + .HostMemory("input") + .HostMemory("output") + .TypeConstraint("T"), + IdentityOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/identity_op.h b/tensorflow/core/kernels/identity_op.h new file mode 100644 index 0000000000..7adc1eace0 --- /dev/null +++ b/tensorflow/core/kernels/identity_op.h @@ -0,0 +1,25 @@ +#ifndef TENSORFLOW_KERNELS_IDENTITY_OP_H_ +#define TENSORFLOW_KERNELS_IDENTITY_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class IdentityOp : public OpKernel { + public: + explicit IdentityOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } + } + + bool IsExpensive() override { return false; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_IDENTITY_OP_H_ diff --git a/tensorflow/core/kernels/identity_op_test.cc b/tensorflow/core/kernels/identity_op_test.cc new file mode 100644 index 0000000000..6483367a79 --- /dev/null +++ b/tensorflow/core/kernels/identity_op_test.cc @@ -0,0 +1,56 @@ +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include + +namespace tensorflow { +namespace { + +class IdentityOpTest : public OpsTestBase { + protected: + Status Init(DataType input_type) { + RequireDefaultOps(); + TF_CHECK_OK(NodeDefBuilder("op", "Identity") + .Input(FakeInput(input_type)) + .Finalize(node_def())); + return InitOp(); + } +}; + +TEST_F(IdentityOpTest, Int32Success_6) { + ASSERT_OK(Init(DT_INT32)); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(IdentityOpTest, Int32Success_2_3) { + ASSERT_OK(Init(DT_INT32)); + AddInputFromArray(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({2, 3})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(IdentityOpTest, 
StringSuccess) { + ASSERT_OK(Init(DT_STRING)); + AddInputFromArray(TensorShape({6}), {"A", "b", "C", "d", "E", "f"}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({6})); + test::FillValues(&expected, {"A", "b", "C", "d", "E", "f"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(IdentityOpTest, RefInputError) { ASSERT_OK(Init(DT_INT32_REF)); } + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/identity_reader_op.cc b/tensorflow/core/kernels/identity_reader_op.cc new file mode 100644 index 0000000000..a63fea5dbb --- /dev/null +++ b/tensorflow/core/kernels/identity_reader_op.cc @@ -0,0 +1,57 @@ +// See docs in ../ops/io_ops.cc. + +#include +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/kernels/reader_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +class IdentityReader : public ReaderBase { + public: + explicit IdentityReader(const string& node_name) + : ReaderBase(strings::StrCat("IdentityReader '", node_name, "'")) {} + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + *key = current_work(); + *value = current_work(); + *produced = true; + *at_end = true; + return Status::OK(); + } + + // Stores state in a ReaderBaseState proto, since IdentityReader has + // no additional state beyond ReaderBase. + Status SerializeStateLocked(string* state) override { + ReaderBaseState base_state; + SaveBaseState(&base_state); + base_state.SerializeToString(state); + return Status::OK(); + } + + Status RestoreStateLocked(const string& state) override { + ReaderBaseState base_state; + if (!ParseProtoUnlimited(&base_state, state)) { + return errors::InvalidArgument("Could not parse state for ", name(), ": ", + str_util::CEscape(state)); + } + TF_RETURN_IF_ERROR(RestoreBaseState(base_state)); + return Status::OK(); + } +}; + +class IdentityReaderOp : public ReaderOpKernel { + public: + explicit IdentityReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + SetReaderFactory([this]() { return new IdentityReader(name()); }); + } +}; + +REGISTER_KERNEL_BUILDER(Name("IdentityReader").Device(DEVICE_CPU), + IdentityReaderOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc new file mode 100644 index 0000000000..d08f6f53da --- /dev/null +++ b/tensorflow/core/kernels/in_topk_op.cc @@ -0,0 +1,58 @@ +// See docs in ../ops/nn_ops.cc. 
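+//
+// Worked example: with k = 2, a predictions row {0.1, 0.8, 0.6, 0.3} and
+// target class 2, exactly one class (index 1, with 0.8) scores strictly
+// higher than the target's 0.6; since 1 < 2, the output for that row is true.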
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +template +class InTopK : public OpKernel { + public: + explicit InTopK(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); + } + + void Compute(OpKernelContext* context) override { + const auto& predictions_in = context->input(0); + const auto& targets_in = context->input(1); + OP_REQUIRES(context, predictions_in.dims() == 2, + errors::InvalidArgument("predictions must be 2-dimensional")); + OP_REQUIRES(context, targets_in.dims() == 1, + errors::InvalidArgument("targets must be 1-dimensional")); + OP_REQUIRES(context, predictions_in.dim_size(0) == targets_in.dim_size(0), + errors::InvalidArgument("First dimension of predictions ", + predictions_in.dim_size(0), + " must match length of targets ", + targets_in.dim_size(0))); + const auto& predictions = predictions_in.matrix(); + const auto& targets = targets_in.vec(); + + Tensor* t_out = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({targets_in.dim_size(0)}), &t_out)); + auto out = t_out->vec(); + + const auto size = targets.size(); + const auto num_classes = predictions.dimension(1); + for (int b = 0; b < size; b++) { + T target_prediction = predictions(b, targets(b)); + int more_probable_classes = 0; + for (int i = 0; i < num_classes; ++i) { + if (predictions(b, i) > target_prediction) ++more_probable_classes; + } + out(b) = more_probable_classes < k_; + } + } + + private: + int k_; +}; + +REGISTER_KERNEL_BUILDER(Name("InTopK").Device(DEVICE_CPU), InTopK); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc new file mode 100644 index 0000000000..7f8b070556 --- /dev/null +++ b/tensorflow/core/kernels/initializable_lookup_table.cc @@ -0,0 +1,41 @@ +#include "tensorflow/core/kernels/initializable_lookup_table.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace lookup { + +Status InitializableLookupTable::Find(const Tensor& keys, Tensor* values, + const Tensor& default_value) { + if (!is_initialized()) { + return errors::FailedPrecondition("Table not initialized."); + } + TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value)); + return DoFind(keys, values, default_value); +} + +Status InitializableLookupTable::Initialize(InitTableIterator& iter) { + if (!iter.Valid()) { + return iter.status(); + } + TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(iter.keys(), iter.values())); + + mutex_lock l(mu_); + if (is_initialized()) { + return errors::FailedPrecondition("Table already initialized."); + } + + TF_RETURN_IF_ERROR(DoPrepare(iter.total_size())); + while (iter.Valid()) { + TF_RETURN_IF_ERROR(DoInsert(iter.keys(), iter.values())); + iter.Next(); + } + if (!errors::IsOutOfRange(iter.status())) { + return iter.status(); + } + is_initialized_ = true; + return Status::OK(); +} + +} // namespace lookup +} // namespace tensorflow diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h new file mode 100644 index 0000000000..651b491457 --- /dev/null +++ b/tensorflow/core/kernels/initializable_lookup_table.h @@ -0,0 +1,103 @@ +#ifndef TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ +#define 
TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ + +#include "tensorflow/core/framework/lookup_interface.h" + +namespace tensorflow { +namespace lookup { + +// Base class for lookup tables that require initialization. +class InitializableLookupTable : public LookupInterface { + public: + class InitTableIterator; + + // Performs batch lookups, for every element in the key tensor, Find returns + // the corresponding value into the values tensor. + // If an element is not present in the table, the given default value is used. + // + // For tables that require initialization, `Find` is available once the table + // is marked as initialized. + // + // Returns the following statuses: + // - OK: when the find finishes successfully. + // - FailedPrecondition: if the table is not initialized. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - In addition, other implementations may provide another non-OK status + // specific to their failure modes. + Status Find(const Tensor& keys, Tensor* values, + const Tensor& default_value) final; + + // Returns whether the table was initialized and is ready to serve lookups. + bool is_initialized() const { return is_initialized_; } + + // Initializes the table from the given init table iterator. + // + // Atomically, this operation prepares the table, populates it with the given + // iterator, and mark the table as initialized. + // + // Returns the following statuses: + // - OK: when the initialization was successful. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - FailedPrecondition: if the table is already initialized and + // fail_if_initialized is set to true. + // - In addition, other implementations may provide another non-OK status + // specific to their failure modes. + Status Initialize(InitTableIterator& iter); + + // Basic iterator to initialize lookup tables. + // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that + // the consumer may insert key-value pairs in batches. + // + // Then the iterator is exhausted, valid returns false and status returns + // Status::OutOfRange. + class InitTableIterator { + public: + InitTableIterator() {} + + virtual ~InitTableIterator() {} + + // Prepares the next batch of key and value tensors. + virtual void Next() = 0; + + // Returns true if keys and values point to valid tensors. + virtual bool Valid() const = 0; + + // Returns a tensor that contains the current batch of 'key' values. + virtual const Tensor& keys() const = 0; + + // Returns a tensor that contains the current batch of 'value' values. + virtual const Tensor& values() const = 0; + + // Returns an error if one has occurred, otherwire returns Status::OK. + virtual Status status() const = 0; + + // Returns the total number of elements that the iterator will produce. + virtual int64 total_size() const = 0; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator); + }; + + protected: + // Prepares and allocates the underlying data structure to store the given + // number of expected elements. + virtual Status DoPrepare(size_t expected_num_elements) = 0; + + // Populates the table in batches given keys and values as tensors into the + // underlying data structure. + virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0; + + // Performs the batch find operation on the underlying data structure. 
+ virtual Status DoFind(const Tensor& keys, Tensor* values, + const Tensor& default_value) = 0; + + mutex mu_; + bool is_initialized_ = false; +}; + +} // namespace lookup +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ diff --git a/tensorflow/core/kernels/io.cc b/tensorflow/core/kernels/io.cc new file mode 100644 index 0000000000..9d6921aa8e --- /dev/null +++ b/tensorflow/core/kernels/io.cc @@ -0,0 +1,270 @@ +// See docs in ../ops/io_ops.cc +#include + +#include "tensorflow/core/kernels/io.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/util/tensor_slice_reader.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include "tensorflow/core/util/tensor_slice_writer.h" + +namespace tensorflow { + +namespace { +bool ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape, + TensorSlice* slice, TensorShape* shape_slice, + string* error) { + CHECK(!shape_and_slice.empty()); + // Syntax: dim0 dim1 dim2 ... + // Where slice string is defined in core/framework/tensor_slice.h + std::vector splits = str_util::Split(shape_and_slice, ' '); + + // Must have at least 2 strings. + if (splits.size() < 2) { + *error = strings::StrCat( + "Need least two elements in shape_and_slice specification: ", + shape_and_slice); + return false; + } + int num_dims = splits.size() - 1; + shape->Clear(); + for (int i = 0; i < num_dims; ++i) { + int dim; + if (!str_util::NumericParse32(splits[i], &dim)) { + *error = strings::StrCat("Non numerical dimension in shape_and_slice: ", + shape_and_slice); + return false; + } + shape->AddDim(dim); + } + // The last split is the slice specification. + slice->Clear(); + auto status = slice->Parse(splits.back(), slice); + if (!status.ok()) { + *error = status.error_message(); + return false; + } + // The specified slice must be compatible with the specified shape. + status = slice->SliceTensorShape(*shape, shape_slice); + if (!status.ok()) { + *error = status.error_message(); + return false; + } + return true; +} +} // namespace + +void SaveTensors( + OpKernelContext* context, + checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, + bool save_slices) { + const Tensor& filename_t = context->input(0); + { + const int64 size = filename_t.NumElements(); + OP_REQUIRES( + context, size == 1, + errors::InvalidArgument( + "Input 0 (filename) must be a string scalar; got a tensor of ", + size, "elements")); + } + + const Tensor& tensor_names_t = context->input(1); + const int64 N = tensor_names_t.NumElements(); + const string* tensor_shapes_and_slices_ptr = nullptr; + if (save_slices) { + const Tensor& tensor_shapes_and_slices_t = context->input(2); + OP_REQUIRES( + context, tensor_shapes_and_slices_t.NumElements() == N, + errors::InvalidArgument("Expected ", N, + " elements for the tensor " + "shapes and slices but got ", + tensor_shapes_and_slices_t.NumElements())); + tensor_shapes_and_slices_ptr = + tensor_shapes_and_slices_t.flat().data(); + } + // Path, names, and slices if save_slices is true. + const int kFixedInputs = save_slices ? 
3 : 2; + OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs, + errors::InvalidArgument("Expected totally ", N + kFixedInputs, + " inputs as input #1 (which is a string " + "tensor of saved names) contains ", + N, " names, but received ", + context->num_inputs(), " inputs")); + + VLOG(1) << "About to save tensors to file " << filename_t.flat()(0) + << "..."; + checkpoint::TensorSliceWriter writer(filename_t.flat()(0), + builder_func); + + Status s; + auto tensor_names_flat = tensor_names_t.flat(); + + string error; + for (int64 i = 0; i < N; ++i) { + const string& name = tensor_names_flat(i); + const Tensor& input = context->input(i + kFixedInputs); + TensorShape shape(input.shape()); + TensorSlice slice(input.dims()); + if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) { + const string& shape_spec = tensor_shapes_and_slices_ptr[i]; + TensorShape slice_shape; + OP_REQUIRES(context, ParseShapeAndSlice(shape_spec, &shape, &slice, + &slice_shape, &error), + errors::InvalidArgument(error)); + OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()), + errors::InvalidArgument("Slice in shape_and_slice " + "specification does not match the " + "shape of the tensor to save: ", + shape_spec, ", tensor: ", + input.shape().DebugString())); + } + +#define WRITER_ADD(dt) \ + case dt: \ + s = writer.Add(name, shape, slice, \ + input.flat::Type>().data()); \ + break + + switch (input.dtype()) { + WRITER_ADD(DT_FLOAT); + WRITER_ADD(DT_DOUBLE); + WRITER_ADD(DT_INT32); + WRITER_ADD(DT_UINT8); + WRITER_ADD(DT_INT16); + WRITER_ADD(DT_INT8); + WRITER_ADD(DT_INT64); + WRITER_ADD(DT_QUINT8); + WRITER_ADD(DT_QINT8); + WRITER_ADD(DT_QINT32); + default: + context->SetStatus(errors::Unimplemented("Saving data type ", + DataTypeString(input.dtype()), + " not yet supported")); + return; + } +#undef WRITER_ADD + if (!s.ok()) { + context->SetStatus(s); + return; + } + } + + s = writer.Finish(); + if (!s.ok()) { + context->SetStatus(s); + } +} + +void RestoreTensor(OpKernelContext* context, + checkpoint::TensorSliceReader::OpenTableFunction open_func, + int preferred_shard, bool restore_slice) { + const Tensor& file_pattern_t = context->input(0); + { + const int64 size = file_pattern_t.NumElements(); + OP_REQUIRES( + context, size == 1, + errors::InvalidArgument( + "Input 0 (file_pattern) must be a string scalar; got a tensor of ", + size, "elements")); + } + const string& file_pattern = file_pattern_t.flat()(0); + + const Tensor& tensor_name_t = context->input(1); + { + const int64 size = tensor_name_t.NumElements(); + OP_REQUIRES( + context, size == 1, + errors::InvalidArgument( + "Input 1 (tensor_name) must be a string scalar; got a tensor of ", + size, "elements")); + } + const string& tensor_name = tensor_name_t.flat()(0); + + const string* tensor_shape_and_slice_ptr = nullptr; + if (restore_slice) { + const Tensor& tensor_shape_and_slice_t = context->input(2); + OP_REQUIRES( + context, tensor_shape_and_slice_t.NumElements() == 1, + errors::InvalidArgument("Expected 1 element for the tensor " + "shape and slice but got ", + tensor_shape_and_slice_t.NumElements())); + tensor_shape_and_slice_ptr = tensor_shape_and_slice_t.flat().data(); + } + + // If we cannot find a cached reader we will allocate our own. 
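+  // The slice reader cache lets multiple restore ops that read from the same
+  // file pattern share a single TensorSliceReader rather than re-opening the
+  // checkpoint once per restored tensor.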
+ std::unique_ptr allocated_reader; + + const checkpoint::TensorSliceReader* reader = + context->slice_reader_cache()->GetReader(file_pattern, open_func, + preferred_shard); + if (!reader) { + allocated_reader.reset(new checkpoint::TensorSliceReader( + file_pattern, open_func, preferred_shard)); + reader = allocated_reader.get(); + } + OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status()); + + // Get the shape and type from the save file. + DataType type; + TensorShape saved_shape; + OP_REQUIRES( + context, reader->HasTensor(tensor_name, &saved_shape, &type), + errors::NotFound("Tensor name \"", tensor_name, + "\" not found in checkpoint files ", file_pattern)); + OP_REQUIRES( + context, type == context->expected_output_dtype(0), + errors::InvalidArgument("Expected to restore a tensor of type ", + DataTypeString(context->expected_output_dtype(0)), + ", got a tensor of type ", DataTypeString(type), + " instead: tensor_name = ", tensor_name)); + + // Shape of the output and slice to load. + TensorShape output_shape(saved_shape); + TensorSlice slice_to_load(saved_shape.dims()); + if (restore_slice && !tensor_shape_and_slice_ptr[0].empty()) { + const string& shape_spec = tensor_shape_and_slice_ptr[0]; + TensorShape parsed_shape; + string error; + OP_REQUIRES(context, + ParseShapeAndSlice(shape_spec, &parsed_shape, &slice_to_load, + &output_shape, &error), + errors::InvalidArgument(error)); + OP_REQUIRES( + context, parsed_shape.IsSameSize(saved_shape), + errors::InvalidArgument( + "Shape in shape_and_slice spec does not match the shape in the " + "save file: ", + parsed_shape.DebugString(), ", save file shape: ", + saved_shape.DebugString())); + } + + Tensor* t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &t)); +#define READER_COPY(dt) \ + case dt: \ + reader->CopySliceData(tensor_name, slice_to_load, \ + t->flat::Type>().data()); \ + break + + switch (type) { + READER_COPY(DT_FLOAT); + READER_COPY(DT_DOUBLE); + READER_COPY(DT_INT32); + READER_COPY(DT_UINT8); + READER_COPY(DT_INT16); + READER_COPY(DT_INT8); + READER_COPY(DT_INT64); + default: + context->SetStatus(errors::Unimplemented( + "Restoring data type ", DataTypeString(type), " not yet supported")); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/io.h b/tensorflow/core/kernels/io.h new file mode 100644 index 0000000000..7e548f1ad0 --- /dev/null +++ b/tensorflow/core/kernels/io.h @@ -0,0 +1,38 @@ +#ifndef TENSORFLOW_KERNELS_IO_H_ +#define TENSORFLOW_KERNELS_IO_H_ + +#include "tensorflow/core/util/tensor_slice_reader.h" +#include "tensorflow/core/util/tensor_slice_writer.h" + +namespace tensorflow { + +class OpKernelContext; + +// Save input tensors in *context to a writer built from builder_func(). +// context must have the following inputs: +// 0: a single element string tensor that contains the file name. +// 1: names for the remaining tensors +// If save_slices is true: +// 2: shape and slice specifications. +// rest: tensors to save +void SaveTensors( + OpKernelContext* context, + checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, + bool save_slices); + +// Reads a tensor from the reader built from open_func() and produces it as +// context->output(0). "preferred_shard" is the same the TensorSliceReader +// preferred_shard parameter. +// +// context must have the following inputs: +// 0: a single element string tensor that contains the file name. +// 1: a single element string tensor that names the output to be restored. 
+// If restore_slice is true: +// 2: shape and slice specification of the tensor to restore. +void RestoreTensor(OpKernelContext* context, + checkpoint::TensorSliceReader::OpenTableFunction open_func, + int preferred_shard, bool restore_slice); + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_IO_H_ diff --git a/tensorflow/core/kernels/l2loss_op.cc b/tensorflow/core/kernels/l2loss_op.cc new file mode 100644 index 0000000000..6f83f01676 --- /dev/null +++ b/tensorflow/core/kernels/l2loss_op.cc @@ -0,0 +1,69 @@ +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/l2loss_op.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class L2LossOp : public OpKernel { + public: + explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // The input tensor can be of any number of dimensions, even though it's + // 2D in most typical applications. + const Tensor& input = context->input(0); + // The output is a single number. + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + functor::L2Loss()(context->eigen_device(), + input.flat(), output->scalar()); + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("L2Loss").Device(DEVICE_CPU).TypeConstraint("T"), \ + L2LossOp); + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void L2Loss::operator()(const GPUDevice& d, \ + typename TTypes::ConstTensor input, \ + typename TTypes::Scalar output); \ + extern template struct L2Loss; + +DECLARE_GPU_SPEC(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("L2Loss").Device(DEVICE_GPU).TypeConstraint("T"), \ + L2LossOp); + +REGISTER_GPU_KERNEL(float); +#undef REGISTER_GPU_KERNEL + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/l2loss_op.h b/tensorflow/core/kernels/l2loss_op.h new file mode 100644 index 0000000000..d307353e24 --- /dev/null +++ b/tensorflow/core/kernels/l2loss_op.h @@ -0,0 +1,24 @@ +#ifndef TENSORFLOW_KERNELS_L2LOSS_OP_H_ +#define TENSORFLOW_KERNELS_L2LOSS_OP_H_ +// Functor definition for L2LossOp, must be compilable by nvcc. +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by L2LossOp to do the computations. +template +struct L2Loss { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::Scalar output) { + // We flatten the input tensor and reduce on dimension 0, producing + // a single number which is Mul(Sum(x^2), 0.5). 
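+    // For example, a 3-element input [1, 2, 3] produces 0.5 * (1 + 4 + 9) = 7.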
+ output.device(d) = input.square().sum() * static_cast(0.5); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_L2LOSS_OP_H_ diff --git a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc new file mode 100644 index 0000000000..858fcfe8d3 --- /dev/null +++ b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc @@ -0,0 +1,16 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/l2loss_op.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +template struct functor::L2Loss; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc new file mode 100644 index 0000000000..93342a7a24 --- /dev/null +++ b/tensorflow/core/kernels/linalg_ops_common.cc @@ -0,0 +1,99 @@ +#include "tensorflow/core/kernels/linalg_ops_common.h" + +namespace tensorflow { + +void LinearAlgebraOpBase::Compute(OpKernelContext* context) { + const Tensor& in = context->input(0); + + const int input_rank = GetInputMatrixRank(); + OP_REQUIRES( + context, input_rank == 2, + errors::InvalidArgument("Only matrix inputs are supported so far.")); + if (SupportsBatchOperation()) { + OP_REQUIRES(context, in.dims() > input_rank, + errors::InvalidArgument("Input tensor must have rank >= %d", + input_rank + 1)); + } else { + OP_REQUIRES(context, in.dims() == input_rank, + errors::InvalidArgument("Input tensor must have rank == %d", + input_rank)); + } + + // If the tensor rank is greater than input_rank, we consider the inner-most + // dimensions as matrices, and loop over all the other outer + // dimensions to compute the results. + // TODO(kalakris): Only matrix inputs are currently supported. + const int row_dimension = in.dims() - 2; + const int col_dimension = in.dims() - 1; + const int64 num_rows = in.dim_size(row_dimension); + const int64 num_cols = in.dim_size(col_dimension); + const TensorShape input_matrix_shape = TensorShape({num_rows, num_cols}); + const TensorShape output_matrix_shape = + GetOutputMatrixShape(input_matrix_shape); + OP_REQUIRES(context, output_matrix_shape.dims() <= 2, + errors::InvalidArgument("Output rank must be 1 or 2.")); + + int num_matrices = 1; + // The output has the shape of all the outer dimensions of the input + // except for the last two, plus the output_matrix_shape (if the output + // is not scalar). This still assumes that each input matrix is + // 2-dimensional, in accordance with the TODO above. 
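+  // For example, a [3, 4, 5, 5] input with a scalar per-matrix result (an
+  // empty output_matrix_shape) yields 12 matrix computations and an output of
+  // shape [3, 4]; an output_matrix_shape of [5, 5] yields a [3, 4, 5, 5]
+  // output.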
+ TensorShape output_shape; + if (in.dims() == 2) { + output_shape = output_matrix_shape; + } else { + for (int dim = 0; dim <= in.dims() - 3; ++dim) { + num_matrices *= in.dim_size(dim); + output_shape.AddDim(in.dim_size(dim)); + } + for (int dim = 0; dim < output_matrix_shape.dims(); ++dim) { + output_shape.AddDim(output_matrix_shape.dim_size(dim)); + } + } + + Tensor* out = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &out)); + + auto shard = [this, &in, &input_matrix_shape, &output_matrix_shape, context, + out](int64 begin, int64 end) { + for (int64 i = begin; i < end; ++i) { + ComputeMatrix(context, i, in, input_matrix_shape, out, + output_matrix_shape); + } + }; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + Shard(worker_threads.num_threads, worker_threads.workers, num_matrices, + GetCostPerUnit(input_matrix_shape), shard); +} + +template +void LinearAlgebraOp::ComputeMatrix( + OpKernelContext* context, int64 matrix_index, const Tensor& in, + const TensorShape& input_matrix_shape, Tensor* out, + const TensorShape& output_matrix_shape) { + // TODO(kalakris): Handle alignment if possible. Eigen::Map is + // unaligned by default. + ConstMatrixMap input(in.flat().data() + + matrix_index * input_matrix_shape.num_elements(), + input_matrix_shape.dim_size(0), + input_matrix_shape.dim_size(1)); + + // The output matrix shape may not be a matrix. + int num_output_rows = + output_matrix_shape.dims() >= 1 ? output_matrix_shape.dim_size(0) : 1; + int num_output_cols = + output_matrix_shape.dims() == 2 ? output_matrix_shape.dim_size(1) : 1; + MatrixMap output(out->flat().data() + + matrix_index * output_matrix_shape.num_elements(), + num_output_rows, num_output_cols); + ComputeMatrix(context, input, &output); +} + +// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use. +template class LinearAlgebraOp; +template class LinearAlgebraOp; +template class LinearAlgebraOp; +template class LinearAlgebraOp; + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h new file mode 100644 index 0000000000..471f11e25f --- /dev/null +++ b/tensorflow/core/kernels/linalg_ops_common.h @@ -0,0 +1,123 @@ +#ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_ +#define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/util/work_sharder.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +// A base class to support linear algebra functionality, similar to the +// numpy.linalg module. Supports batch computation on several matrices at once, +// sharding the computations across different threads if necessary. +// +// TODO(kalakris): This needs to be expanded to support binary inputs, and +// multiple outputs. +class LinearAlgebraOpBase : public OpKernel { + public: + explicit LinearAlgebraOpBase(OpKernelConstruction* context) + : OpKernel(context) {} + ~LinearAlgebraOpBase() override {} + + // Return the expected rank of the input. 
+ // TODO(kalakris): This should be a virtual function to support vector inputs. + int GetInputMatrixRank() { return 2; } + + // Return the output shape of each individual matrix operation. Must be + // rank 0, 1, or 2. Scalar outputs are rank 0. + virtual TensorShape GetOutputMatrixShape( + const TensorShape& input_matrix_shape) = 0; + + // Return the cost per matrix operation. Cost per unit is assumed to be + // roughly 1ns, based on comments in core/util/work_sharder.cc. + virtual int64 GetCostPerUnit(const TensorShape& input_matrix_shape) = 0; + + // If SupportsBatchOperation() returns false, this Op will only accept rank 2 + // (if the supported input type is a matrix). If it returns true, the Op will + // accept inputs of rank >= 3, and repeatedly execute the operation on all + // matrices in the innermost two dimensions. + virtual bool SupportsBatchOperation() = 0; + + // Perform the actual computation on an input matrix, and store the results + // in the output. This will be called repeatedly for a single call to + // Compute(), if multiple matrices exist in the input Tensor. + // + // This function should only compute the results for a single input matrix. + // The 'matrix_index' parameter specifies the index of the matrix to be used + // from the input, and the index of the matrix to be written to in the output. + // The input matrix is in row major order, and is located at the memory + // address + // in.flat().data() + + // matrix_index * input_matrix_shape.num_elements(). + // The output matrix is in row major order, and is located at the memory + // address + // out->flat().data() + + // matrix_index * output_matrix_shape.num_elements(). + // The LinearAlgebraOp class below has functionality which performs + // this mapping and presents an interface based on the Eigen::MatrixBase API. + virtual void ComputeMatrix(OpKernelContext* context, int64 matrix_index, + const Tensor& in, + const TensorShape& input_matrix_shape, Tensor* out, + const TensorShape& output_matrix_shape) = 0; + + void Compute(OpKernelContext* context) override; +}; + +// A base class for linear algebra ops templated on the scalar type. +// +// This base class encapsulates the functionality of mapping the input and +// output tensors using Eigen::Map, so that the Eigen::MatrixBase API may be +// directly used by derived classes. +// SupportsBatchOperationT is a bool template argument which if set to true +// will allow the Op to process batches of matrices (rank >= 3); if set to +// false the Op will only accept rank 2 inputs. +template +class LinearAlgebraOp : public LinearAlgebraOpBase { + public: + explicit LinearAlgebraOp(OpKernelConstruction* context) + : LinearAlgebraOpBase(context) {} + + using ConstMatrixMap = + Eigen::Map>; + using MatrixMap = Eigen::Map< + Eigen::Matrix>; + + // Perform the actual computation on the input matrix, and store the results + // in the output. This will be called repeatedly for a single call to + // Compute(), if multiple matrices exist in the input Tensor. + virtual void ComputeMatrix(OpKernelContext* context, + const ConstMatrixMap& input, + MatrixMap* output) = 0; + + bool SupportsBatchOperation() final { return SupportsBatchOperationT; } + + // A concrete implementation of LinearAlgebraOpBase::ComputeMatrix(). 
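+  // It maps matrix_index into Eigen::Maps over the input and output tensors
+  // and forwards to the overload above. A hypothetical derived op, e.g. one
+  // returning the trace of each (square) matrix, would only need something
+  // like (sketch, omitting the other required overrides):
+  //
+  //   void ComputeMatrix(OpKernelContext* context,
+  //                      const ConstMatrixMap& input,
+  //                      MatrixMap* output) override {
+  //     (*output)(0, 0) = input.trace();
+  //   }
+  //
+  //   TensorShape GetOutputMatrixShape(const TensorShape&) override {
+  //     return TensorShape({});  // a scalar result per input matrix
+  //   }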
+ void ComputeMatrix(OpKernelContext* context, int64 matrix_index, + const Tensor& in, const TensorShape& input_matrix_shape, + Tensor* out, const TensorShape& output_matrix_shape) final; +}; + +// Declare that LinearAlgebraOp is explicitly instantiated in +// linalg_ops_common.cc for float and double. +extern template class LinearAlgebraOp; +extern template class LinearAlgebraOp; +extern template class LinearAlgebraOp; +extern template class LinearAlgebraOp; + +} // namespace tensorflow + +#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \ + REGISTER_KERNEL_BUILDER( \ + Name(OpName).Device(DEVICE_CPU).TypeConstraint("T"), OpClass) + +#endif // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_ diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc new file mode 100644 index 0000000000..f490f5ddd3 --- /dev/null +++ b/tensorflow/core/kernels/listdiff_op.cc @@ -0,0 +1,75 @@ +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +template +class ListDiffOp : public OpKernel { + public: + explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32})); + } + + void Compute(OpKernelContext* context) override { + const Tensor& x = context->input(0); + const Tensor& y = context->input(1); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(x.shape()), + errors::InvalidArgument("x should be a 1D vector.")); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(y.shape()), + errors::InvalidArgument("y should be a 1D vector.")); + + std::unordered_set y_set; + const auto Ty = y.vec(); + const int y_size = Ty.size(); + y_set.reserve(y_size); + for (int i = 0; i < y_size; ++i) { + y_set.insert(Ty(i)); + } + + // Compute the size of the output. + const auto Tx = x.vec(); + const int x_size = Tx.size(); + + int out_size = 0; + for (int i = 0; i < x_size; ++i) { + if (y_set.count(Tx(i)) == 0) { + ++out_size; + } + } + + // Allocate and populate outputs. 
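+    // For example, with x = [1, 2, 3, 4, 5, 6] and y = [1, 3, 5], out_size is
+    // 3 and the loop below fills out = [2, 4, 6] and indices = [1, 3, 5].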
+ Tensor* out = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, {out_size}, &out)); + auto Tout = out->vec(); + + Tensor* indices = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices)); + auto Tindices = indices->vec(); + + for (int i = 0, p = 0; i < x_size; ++i) { + if (y_set.count(Tx(i)) == 0) { + Tout(p) = Tx(i); + Tindices(p) = i; + p++; + } + } + } +}; + +#define REGISTER_LISTDIFF(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ListDiff").Device(DEVICE_CPU).TypeConstraint("T"), \ + ListDiffOp) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF); +#undef REGISTER_LISTDIFF + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc new file mode 100644 index 0000000000..ec84145f75 --- /dev/null +++ b/tensorflow/core/kernels/logging_ops.cc @@ -0,0 +1,77 @@ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +class AssertOp : public OpKernel { + public: + explicit AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& cond = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(cond.shape()), + errors::InvalidArgument("In[0] should be a scalar: ", + cond.shape().ShortDebugString())); + + if (cond.scalar()()) { + return; + } + string msg = "assertion failed: "; + for (int i = 1; i < ctx->num_inputs(); ++i) { + strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_), + "]"); + if (i < ctx->num_inputs() - 1) strings::StrAppend(&msg, " "); + } + ctx->SetStatus(errors::InvalidArgument(msg)); + } + + private: + int32 summarize_ = 0; +}; + +REGISTER_KERNEL_BUILDER(Name("Assert").Device(DEVICE_CPU), AssertOp); + +class PrintOp : public OpKernel { + public: + explicit PrintOp(OpKernelConstruction* ctx) + : OpKernel(ctx), call_counter_(0) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &message_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("first_n", &first_n_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_)); + } + + void Compute(OpKernelContext* ctx) override { + if (IsRefType(ctx->input_dtype(0))) { + ctx->forward_ref_input_to_ref_output(0, 0); + } else { + ctx->set_output(0, ctx->input(0)); + } + if (first_n_ >= 0) { + mutex_lock l(mu_); + if (call_counter_ >= first_n_) return; + call_counter_++; + } + string msg; + strings::StrAppend(&msg, message_); + for (int i = 1; i < ctx->num_inputs(); ++i) { + strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_), + "]"); + } + LOG(INFO) << msg; + } + + private: + mutex mu_; + int64 call_counter_ GUARDED_BY(mu_) = 0; + int64 first_n_ = 0; + int32 summarize_ = 0; + string message_; +}; + +REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc new file mode 100644 index 0000000000..a7af6eb303 --- /dev/null +++ b/tensorflow/core/kernels/logging_ops_test.cc @@ -0,0 +1,87 @@ +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include 
"tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace { + +class PrintingGraphTest : public OpsTestBase { + protected: + Status Init(DataType input_type1, DataType input_type2, string msg = "", + int first_n = -1, int summarize = 3) { + RequireDefaultOps(); + TF_CHECK_OK(NodeDefBuilder("op", "Print") + .Input(FakeInput(input_type1)) + .Input(FakeInput(2, input_type2)) + .Attr("message", msg) + .Attr("first_n", first_n) + .Attr("summarize", summarize) + .Finalize(node_def())); + return InitOp(); + } +}; + +TEST_F(PrintingGraphTest, Int32Success_6) { + ASSERT_OK(Init(DT_INT32, DT_INT32)); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(PrintingGraphTest, Int32Success_Summarize6) { + ASSERT_OK(Init(DT_INT32, DT_INT32, "", -1, 6)); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(PrintingGraphTest, StringSuccess) { + ASSERT_OK(Init(DT_INT32, DT_STRING)); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({}), {"foo"}); + AddInputFromArray(TensorShape({}), {"bar"}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(PrintingGraphTest, MsgSuccess) { + ASSERT_OK(Init(DT_INT32, DT_STRING, "Message: ")); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({}), {"foo"}); + AddInputFromArray(TensorShape({}), {"bar"}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(PrintingGraphTest, FirstNSuccess) { + ASSERT_OK(Init(DT_INT32, DT_STRING, "", 3)); + AddInputFromArray(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({}), {"foo"}); + AddInputFromArray(TensorShape({}), {"bar"}); + // run 4 times but we only print 3 as intended + for (int i = 0; i < 4; i++) ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +} // end namespace +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc new file mode 100644 index 0000000000..9781bcfa59 --- /dev/null +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -0,0 +1,116 @@ +#define EIGEN_USE_THREADS + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include 
"tensorflow/core/kernels/initializable_lookup_table.h" +#include "tensorflow/core/kernels/lookup_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace lookup { + +// Iterator to initialize tables given 'keys' and 'values' tensors. +// +// The two tensors are returned in the first iteration. It doesn't loop +// over each element of the tensor since insertions in the lookup table can +// process batches. +class KeyValueTensorIterator + : public InitializableLookupTable::InitTableIterator { + public: + // keys and values are not owned by the iterator. + explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values) + : keys_(keys), values_(values), valid_(true), status_(Status::OK()) { + TensorShape key_shape = keys_->shape(); + if (!key_shape.IsSameSize(values_->shape())) { + valid_ = false; + status_ = errors::InvalidArgument( + "keys and values should have the same dimension.", + key_shape.DebugString(), " vs ", values_->shape().DebugString()); + } + if (key_shape.num_elements() == 0) { + valid_ = false; + status_ = + errors::InvalidArgument("keys and values cannot be empty tensors."); + } + } + + bool Valid() const override { return valid_; } + + void Next() override { + valid_ = false; + status_ = errors::OutOfRange("No more data."); + } + + const Tensor& keys() const override { return *keys_; } + + const Tensor& values() const override { return *values_; } + + Status status() const override { return status_; } + + int64 total_size() const { + return keys_ == nullptr ? -1 : keys_->NumElements(); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator); + + const Tensor* keys_; // Doesn't own it. + const Tensor* values_; // Doesn't own it. + bool valid_; // true if the iterator points to an existing range. + Status status_; +}; + +} // namespace lookup + +// Kernel to initialize a look table given a key and value tensors. +// After this operation, the table becomes read-only. 
+class InitializeTableOp : public OpKernel { + public: + explicit InitializeTableOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + mutex_lock l(mu_); + lookup::InitializableLookupTable* table; + OP_REQUIRES_OK(ctx, + GetInitializableLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + table->value_dtype()}; + DataTypeVector expected_outputs = {}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); + + const Tensor& keys = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(keys.shape()), + errors::InvalidArgument("Keys must be a vector, but received ", + keys.shape().DebugString())); + + const Tensor& values = ctx->input(2); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(values.shape()), + errors::InvalidArgument("Values must be a vector, but received ", + values.shape().DebugString())); + + OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(), + errors::InvalidArgument( + "Keys and values must have the same size ", + keys.NumElements(), " vs ", values.NumElements())); + + lookup::KeyValueTensorIterator iter(&keys, &values); + OP_REQUIRES_OK(ctx, table->Initialize(iter)); + } + + private: + mutex mu_; +}; + +REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU), + InitializeTableOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc new file mode 100644 index 0000000000..2bab4df94f --- /dev/null +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -0,0 +1,166 @@ +#include "tensorflow/core/kernels/lookup_table_op.h" +#define EIGEN_USE_THREADS + +#include +#include + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/initializable_lookup_table.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { +namespace lookup { + +// Lookup table that wraps an unordered_map, where the key and value data type +// is specified. +// +// This table is recommened for any variations to key values. +// +// For look up, the table is required to be initialized (allocated +// and populated). Once the table is marked as initialized it becomes read-only. +// +// Sample use case: +// +// HashTable table; // int64 -> int64. +// table.Prepare(10); // Prepare the underlying data structure, the number of +// // elements is required by interface, but not used. +// // Populate the table, elements could be added in one or multiple calls. +// table.Insert(key_tensor, value_tensor); // Populate the table. +// ... +// table.set_is_initialized(); +// +// table.Find(in_t, &out_t, default_t) +// +template +class HashTable : public InitializableLookupTable { + public: + size_t size() const override { return table_ ? 
table_->size() : 0; } + + DataType key_dtype() const override { return DataTypeToEnum::v(); } + + DataType value_dtype() const override { return DataTypeToEnum::v(); } + + protected: + Status DoPrepare(size_t unused) override { + if (is_initialized_) { + return errors::Aborted("HashTable already initialized."); + } + if (!table_) { + table_ = std::unique_ptr>( + new std::unordered_map()); + } + return Status::OK(); + }; + + Status DoInsert(const Tensor& keys, const Tensor& values) override { + if (!table_) { + return errors::FailedPrecondition("HashTable is not prepared."); + } + + const auto key_values = keys.flat(); + const auto value_values = values.flat(); + for (size_t i = 0; i < key_values.size(); ++i) { + const K& key = key_values(i); + const V& value = value_values(i); + const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value); + if (previous_value != value) { + return errors::FailedPrecondition( + "HashTable has different value for same key. Key ", key, " has ", + previous_value, " and trying to add value ", value); + } + } + return Status::OK(); + } + + Status DoFind(const Tensor& key, Tensor* value, + const Tensor& default_value) override { + const V default_val = default_value.flat()(0); + const auto key_values = key.flat(); + auto value_values = value->flat(); + + for (size_t i = 0; i < key_values.size(); ++i) { + value_values(i) = + gtl::FindWithDefault(*table_, key_values(i), default_val); + } + return Status::OK(); + } + + private: + std::unique_ptr> table_; +}; + +} // namespace lookup + +// Table lookup op. Perform the lookup operation on the given table. +class LookupTableFindOp : public OpKernel { + public: + explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table; + OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + table->value_dtype()}; + DataTypeVector expected_outputs = {table->value_dtype()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); + + const Tensor& input = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input.shape()), + errors::InvalidArgument("Input must be a vector, not ", + input.shape().DebugString())); + + const Tensor& default_value = ctx->input(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(default_value.shape()), + errors::InvalidArgument("Default value must be a scalar, not ", + default_value.shape().DebugString())); + + Tensor* out; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("output_values", input.shape(), &out)); + + OP_REQUIRES_OK(ctx, table->Find(input, out, default_value)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU), + LookupTableFindOp); + +// Op that returns the size of the given table. +class LookupTableSizeOp : public OpKernel { + public: + explicit LookupTableSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table; + OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + Tensor* out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("size", TensorShape({}), &out)); + out->flat().setConstant(table->size()); + } +}; + +REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU), + LookupTableSizeOp); + +// Register the HashTable op with the currently supported key and value types. 
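+// For example, REGISTER_KERNEL(string, int64) below provides a CPU kernel for
+// string keys mapped to int64 values; supporting another key/value pair only
+// requires adding a corresponding REGISTER_KERNEL line (assuming the types
+// work with the HashTable implementation above).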
+#define REGISTER_KERNEL(key_dtype, value_dtype) \ + REGISTER_KERNEL_BUILDER( \ + Name("HashTable") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + LookupTableOp, key_dtype, \ + value_dtype>) + +REGISTER_KERNEL(string, int64); +REGISTER_KERNEL(int64, string); + +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h new file mode 100644 index 0000000000..dc53ce33a6 --- /dev/null +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -0,0 +1,80 @@ +#ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_ +#define TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_ + +#include "tensorflow/core/framework/lookup_interface.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/lookup_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +// Lookup table op that supports different table implementations specified by +// the 'Container' template. Container must be derived from LookupInterface. The +// key and value are of the templated type "key_dtype" and "value_dtype" +// respectively. +template +class LookupTableOp : public OpKernel { + public: + // ctx is not owned by this class. + explicit LookupTableOp(OpKernelConstruction* ctx) + : OpKernel(ctx), table_handle_set_(false) { + OP_REQUIRES_OK(ctx, ctx->allocate_persistent(tensorflow::DT_STRING, + tensorflow::TensorShape({2}), + &table_handle_, nullptr)); + } + + // ctx is not owned by this function. + void Compute(OpKernelContext* ctx) override { + mutex_lock l(mu_); + if (!table_handle_set_) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); + auto creator = [this](lookup::LookupInterface** ret) { + *ret = new Container(); + return Status::OK(); + }; + + lookup::LookupInterface* table = nullptr; + OP_REQUIRES_OK( + ctx, cinfo_.resource_manager() + ->template LookupOrCreate( + cinfo_.container(), cinfo_.name(), &table, creator)); + core::ScopedUnref unref_me(table); + + OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes( + *table, DataTypeToEnum::v(), + DataTypeToEnum::v(), cinfo_.name())); + + auto h = table_handle_.AccessTensor(ctx)->template flat(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + table_handle_set_ = true; + } + ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx)); + } + + ~LookupTableOp() override { + // If the table object was not shared, delete it. 
+ if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) { + TF_CHECK_OK( + cinfo_.resource_manager()->template Delete( + cinfo_.container(), cinfo_.name())); + } + } + + private: + mutex mu_; + PersistentTensor table_handle_ GUARDED_BY(mu_); + bool table_handle_set_ GUARDED_BY(mu_); + ContainerInfo cinfo_; + + TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_ diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc new file mode 100644 index 0000000000..634c11e4a5 --- /dev/null +++ b/tensorflow/core/kernels/lookup_util.cc @@ -0,0 +1,72 @@ +#include "tensorflow/core/kernels/lookup_util.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { +namespace lookup { +namespace { + +Status GetTableHandle(const string& input_name, OpKernelContext* ctx, + string* container, string* table_handle) { + { + mutex* mu; + TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); + mutex_lock l(*mu); + Tensor tensor; + TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); + if (tensor.NumElements() != 2) { + return errors::InvalidArgument( + "Lookup table handle must be scalar, but had shape: ", + tensor.shape().DebugString()); + } + auto h = tensor.flat(); + *container = h(0); + *table_handle = h(1); + } + return Status::OK(); +} + +} // namespace + +Status GetLookupTable(const string& input_name, OpKernelContext* ctx, + LookupInterface** table) { + string container; + string table_handle; + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + return ctx->resource_manager()->Lookup(container, table_handle, table); +} + +Status GetInitializableLookupTable(const string& input_name, + OpKernelContext* ctx, + InitializableLookupTable** table) { + string container; + string table_handle; + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + LookupInterface* lookup_table; + TF_RETURN_IF_ERROR( + ctx->resource_manager()->Lookup(container, table_handle, &lookup_table)); + *table = dynamic_cast(lookup_table); + if (*table == nullptr) { + lookup_table->Unref(); + return errors::InvalidArgument("Table ", container, " ", table_handle, + " is not initializable"); + } + return Status::OK(); +} + +Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, + DataType value_dtype, const string& table_name) { + if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) { + return errors::InvalidArgument( + "Conflicting key/value dtypes ", key_dtype, "->", value_dtype, " with ", + table.key_dtype(), "-", table.value_dtype(), " for table ", table_name); + } + return Status::OK(); +} + +} // namespace lookup +} // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_util.h b/tensorflow/core/kernels/lookup_util.h new file mode 100644 index 0000000000..991a757edd --- /dev/null +++ b/tensorflow/core/kernels/lookup_util.h @@ -0,0 +1,31 @@ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ + +#include "tensorflow/core/framework/lookup_interface.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/initializable_lookup_table.h" + +namespace tensorflow { +namespace lookup { + +// Gets the LookupTable stored in the ctx->resource_manager() with key +// passed by attribute with 
name input_name, returns null if the table +// doesn't exist. +Status GetLookupTable(const string& input_name, OpKernelContext* ctx, + LookupInterface** table); + +// Gets the InitializableLookupTable stored in the +// ctx->resource_manager() with key passed by attribute with name +// input_name, returns null if the table doesn't exist. +Status GetInitializableLookupTable(const string& input_name, + OpKernelContext* ctx, + InitializableLookupTable** table); + +// Verify that the given key_dtype and value_dtype matches the corresponding +// table's data types. +Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, + DataType value_dtype, const string& table_name); +} // namespace lookup +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc new file mode 100644 index 0000000000..e5abf5906f --- /dev/null +++ b/tensorflow/core/kernels/lrn_op.cc @@ -0,0 +1,228 @@ +// LRN = Local Response Normalization +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#ifndef __ANDROID__ +#include "tensorflow/core/util/work_sharder.h" +#endif + +namespace tensorflow { + +// Create a depth-by-depth band matrix with 1s along a swath of size (2 * +// depth_radius + 1) around the diagonal. +static void GetBandMatrix(int depth, int64 depth_radius, + Eigen::Tensor* result) { + result->setZero(); + for (int row = 0; row < depth; ++row) { + const int begin = std::max(0, row - depth_radius); + const int end = std::min(depth, row + depth_radius + 1); + Eigen::DSizes start(row, begin); + Eigen::DSizes sizes(1, end - begin); + result->slice(start, sizes).setConstant(1.0f); + } +} + +class LRNOp : public OpKernel { + public: + explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius_)); + OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_)); + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_)); + OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& in = context->input(0); + OP_REQUIRES(context, in.dims() == 4, + errors::InvalidArgument("in must be 4-dimensional")); + const int64 batch = in.dim_size(0); + const int64 rows = in.dim_size(1); + const int64 cols = in.dim_size(2); + const int64 depth = in.dim_size(3); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({batch, rows, cols, depth}), &output)); + +#ifdef __ANDROID__ + MognetLRN(in, batch, rows, cols, depth, output); +#else + const int nodes = cols * rows; + auto in_shaped = in.shaped({nodes * batch, depth}); + + // Multiplying the input with the band matrix has the effect of reducing the + // correct patch along the depth. 
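+    // For example, with depth = 5 and depth_radius = 1, GetBandMatrix yields
+    //   1 1 0 0 0
+    //   1 1 1 0 0
+    //   0 1 1 1 0
+    //   0 0 1 1 1
+    //   0 0 0 1 1
+    // so each output channel accumulates the squared inputs within its depth
+    // window.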
+ Eigen::Tensor multiplier(depth, depth); + GetBandMatrix(depth, depth_radius_, &multiplier); + + auto out_shaped = output->shaped({nodes * batch, depth}); + Eigen::array dims = {{DimPair(1, 0)}}; + /// TODO(keveman): Optimize for beta in {0, 1, 0.5} + out_shaped.device(context->eigen_cpu_device()) = + in_shaped / + in_shaped.square() + .contract(multiplier, dims) + .unaryExpr([this](float x) { return bias_ + alpha_ * x; }) + .pow(beta_); +#endif + } + + private: + typedef Eigen::Tensor::DimensionPair DimPair; + + void MognetLRN(const Tensor& in, const int batch, const int rows, + const int cols, const int depth, Tensor* out) { + Eigen::Map> + data_in(in.flat().data(), depth, batch * rows * cols); + + Eigen::Map> data_out( + out->flat().data(), depth, batch * rows * cols); + + const int double_depth_radius = depth_radius_ * 2; + Eigen::VectorXf padded_square(data_in.rows() + double_depth_radius); + padded_square.setZero(); + for (int r = 0; r < data_in.cols(); ++r) { + // Do local response normalization for data_in(:, r) + // first, compute the square and store them in buffer for repeated use + padded_square.block(depth_radius_, 0, data_out.rows(), 1) = + data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_; + // Then, compute the scale and writes them to data_out + float accumulated_scale = 0; + for (int i = 0; i < double_depth_radius; ++i) { + accumulated_scale += padded_square(i); + } + for (int i = 0; i < data_in.rows(); ++i) { + accumulated_scale += padded_square(i + double_depth_radius); + data_out(i, r) = bias_ + accumulated_scale; + accumulated_scale -= padded_square(i); + } + } + + // In a few cases, the pow computation could benefit from speedups. + if (beta_ == 1) { + data_out.array() = data_in.array() * data_out.array().inverse(); + } else if (beta_ == 0.5) { + data_out.array() = data_in.array() * data_out.array().sqrt().inverse(); + } else { + data_out.array() = data_in.array() * data_out.array().pow(-beta_); + } + } + + int64 depth_radius_; + float bias_; + float alpha_; + float beta_; +}; + +REGISTER_KERNEL_BUILDER(Name("LRN").Device(DEVICE_CPU), LRNOp); + +#ifndef __ANDROID__ + +class LRNGradOp : public OpKernel { + public: + explicit LRNGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius_)); + OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_)); + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_)); + OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& in_grads = context->input(0); + const Tensor& in_image = context->input(1); + const Tensor& out_image = context->input(2); + + OP_REQUIRES(context, in_grads.dims() == 4 && in_image.dims() == 4, + errors::InvalidArgument("inputs must be 4-dimensional")); + const int64 batch = in_grads.dim_size(0); + const int64 rows = in_grads.dim_size(1); + const int64 cols = in_grads.dim_size(2); + const int64 depth = in_grads.dim_size(3); + OP_REQUIRES( + context, + in_image.dim_size(0) == batch && in_image.dim_size(1) == rows && + in_image.dim_size(2) == cols && in_image.dim_size(3) == depth && + out_image.dim_size(0) == batch && out_image.dim_size(1) == rows && + out_image.dim_size(2) == cols && out_image.dim_size(3) == depth, + errors::InvalidArgument( + "input_grads, input_image, and out_image should have the same " + "shape")); + const auto nodes = cols * rows; + auto grads_shaped = in_grads.shaped({nodes * batch, depth}); + auto in_shaped = 
in_image.shaped({nodes * batch, depth}); + auto activations = out_image.shaped({nodes * batch, depth}); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({batch, rows, cols, depth}), &output)); + auto out_shaped = output->shaped({nodes * batch, depth}); + out_shaped.setZero(); + + auto shard = [this, activations, in_shaped, grads_shaped, out_shaped, + depth](int64 begin, int64 end) { + for (int64 i = begin; i < end; ++i) { + for (int64 j = 0; j < depth; ++j) { + // Let y be the LRN activations and x be the inputs along the depth + // dimension. (LRN operates independently along rows, cols, and + // batch). + // We have + // yi = xi / (bias + alpha(sum_j_{i - depth_radius}^{i + depth_radius} + // x_j^2))^beta + // + // Let N = (bias + alpha(sum_j_{i - depth_radius}^{i + depth_radius} + // x_j^2)) + // dy_i/dx_i = (N^beta - xi. beta*N^(beta-1)*2*alpha*xi)/N^(2*beta) + // dy_i/dx_j = ( - xi. beta*N^(beta-1)*2*alpha*xj)/N^(2*beta) + // + // NOTE(keveman) : We can compute N by doing (yi/xi) ^ (1/beta). + // However, this is numerically unstable for small values of xi. We + // compute N explicitly here to avoid that. + + int64 depth_begin = std::max(0, j - depth_radius_); + int64 depth_end = std::min(depth, j + depth_radius_ + 1); + + float norm = 0.0f; + for (int64 k = depth_begin; k < depth_end; ++k) { + norm += in_shaped(i, k) * in_shaped(i, k); + } + norm = alpha_ * norm + bias_; + DCHECK_GT(norm, 1e-6); + for (int64 k = depth_begin; k < depth_end; ++k) { + float dyi = -2.0f * alpha_ * beta_ * in_shaped(i, k) * + activations(i, j) / norm; + if (k == j) { + dyi += std::pow(norm, -beta_); + } + dyi *= grads_shaped(i, j); + const_cast::Tensor&>(out_shaped)(i, k) += dyi; + } + } + } + }; + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch, + depth * depth, shard); + } + + private: + typedef Eigen::Tensor::DimensionPair DimPair; + + int64 depth_radius_; + float bias_; + float alpha_; + float beta_; +}; + +REGISTER_KERNEL_BUILDER(Name("LRNGrad").Device(DEVICE_CPU), LRNGradOp); + +#endif // __ANDROID__ + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/lrn_op_test.cc b/tensorflow/core/kernels/lrn_op_test.cc new file mode 100644 index 0000000000..4c338b6cb3 --- /dev/null +++ b/tensorflow/core/kernels/lrn_op_test.cc @@ -0,0 +1,185 @@ +#include +#include +#include + +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +static const float tol_ = 1e-4; + +class LRNFloatTest : public OpsTestBase { + protected: + LRNFloatTest() : philox_(123, 17), rand_(&philox_) { RequireDefaultOps(); } + + int GetIntAttr(const string& name) { + int value; + TF_CHECK_OK(GetNodeAttr(*node_def(), name, &value)); + return value; + } + + float GetFloatAttr(const string& name) { + float value; + TF_CHECK_OK(GetNodeAttr(*node_def(), name, &value)); + return value; + } + + bool 
Compare() { + const auto& input = GetInput(0); + const int64 batch_size = input.dim_size(0); + const int64 rows = input.dim_size(1); + const int64 cols = input.dim_size(2); + const int64 depth = input.dim_size(3); + const int64 rest = cols * rows * batch_size; + + const int64 depth_radius = GetIntAttr("depth_radius"); + const float bias = GetFloatAttr("bias"); + const float alpha = GetFloatAttr("alpha"); + const float beta = GetFloatAttr("beta"); + + Eigen::Tensor expected(batch_size, rows, cols, + depth); + auto out = expected.reshape(Eigen::DSizes{rest, depth}); + auto in = input.shaped({rest, depth}); + + for (int64 i = 0; i < rest; ++i) { + Eigen::Tensor out_col(depth); + for (int64 d = 0; d < depth; ++d) { + float denom = 0.0f; + for (int64 r = std::max(0ll, d - depth_radius); + r < std::min(depth, d + depth_radius + 1); ++r) { + denom += in(i, r) * in(i, r); + } + denom = std::pow(denom * alpha + bias, beta); + out_col(d) = in(i, d) / denom; + } + out.chip<0>(i) = out_col; + } + auto actual = GetOutput(0)->tensor(); + Eigen::Tensor sum = + ((expected - actual).abs() > actual.constant(tol_)) + .select(actual.constant(1), actual.constant(0)) + .sum(); + return sum() == 0; + } + + random::PhiloxRandom philox_; + random::SimplePhilox rand_; +}; + +TEST_F(LRNFloatTest, Depth96) { + ASSERT_OK(NodeDefBuilder("lrn_op", "LRN") + .Input(FakeInput()) + .Attr("depth_radius", 5) + .Attr("bias", 1.0f) + .Attr("alpha", 0.1f) + .Attr("beta", 2.0f) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + AddInput(TensorShape({1, 1, 1, 96}), + [this](int i) -> float { return i + 1; }); + ASSERT_OK(RunOpKernel()); + auto actual = GetOutput(0)->tensor(); + + // Output for Node 0 with Value 1: + // 1 / (1 + 0.1*(1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2))^2 + EXPECT_NEAR(1. / (10.1 * 10.1), actual(0, 0, 0, 0), tol_); + + // Output for Node 5 with Value 6: + // 6 / (1 + 0.1*(1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2 ... + 11^2))^2 + EXPECT_NEAR(6. / (51.6 * 51.6), actual(0, 0, 0, 5), tol_); + + // Output for Node 63 with value 64: + // 64 / (1 + 0.1*(59^2 + 60^2 + 61^2 + 62^2 + 63^2 + 64^2))^2 + EXPECT_NEAR(64. / (2272.1 * 2272.1), actual(0, 0, 0, 63), tol_); + + // Output for Node 64 with value 65: + // 65 / (1 + 0.1*(65^2 + 66^2 + 67^2 + 68^2 + 69^2 + 70^2))^2 + EXPECT_NEAR(65. / (2736.5 * 2736.5), actual(0, 0, 0, 64), tol_); + + // Output for Node 95 with value 96: + // 96 / (1 + 0.1*(91^2 + 92^2 + 93^2 + 94^2 + 95^2 + 96^2))^2 + EXPECT_NEAR(96. / (5248.1 * 5248.1), actual(0, 0, 0, 95), tol_); + EXPECT_TRUE(Compare()); +} + +TEST_F(LRNFloatTest, Depth16) { + ASSERT_OK(NodeDefBuilder("lrn_op", "LRN") + .Input(FakeInput()) + .Attr("depth_radius", 5) + .Attr("bias", 1.0f) + .Attr("alpha", 0.1f) + .Attr("beta", 2.0f) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + AddInput(TensorShape({1, 1, 1, 16}), + [this](int i) -> float { return i + 1; }); + ASSERT_OK(RunOpKernel()); + auto actual = GetOutput(0)->tensor(); + + // Output for Node 0 with Value 1: + // 1 / (1 + 0.1*(1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2))^2 + EXPECT_NEAR(1. / (10.1 * 10.1), actual(0, 0, 0, 0), tol_); + + // Output for Node 5 with Value 6: + // 6 / (1 + 0.1*(1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2 ... + 11^2))^2 + EXPECT_NEAR(6. / (51.6 * 51.6), actual(0, 0, 0, 5), tol_); + + // Output for Node 15 with value 16: + // 16 / (1 + 0.1*(11^2 + 12^2 + 13^2 + 14^2 + 15^2 + 16^2))^2 + EXPECT_NEAR(16. / (112.1 * 112.1), actual(0, 0, 0, 15), tol_); + EXPECT_TRUE(Compare()); +} + +static double RndGaussian(random::SimplePhilox* rnd) { + // Box-Muller transformation. 
+ // See, for example, http://www.taygeta.com/random/gaussian.html + double x1, x2; + double r; + do { + x1 = 2 * rnd->RandDouble() - 1; + x2 = 2 * rnd->RandDouble() - 1; + r = x1 * x1 + x2 * x2; + } while (r == 0 || r >= 1.0); + double w = sqrt(-2.0 * log(r) / r); + return x1 * w; +} + +#define TCASE(NAME, DEPTH, BATCH, DEPTH_RADIUS, BIAS, ALPHA, BETA) \ + TEST_F(LRNFloatTest, NAME) { \ + ASSERT_OK(NodeDefBuilder("lrn_op", "LRN") \ + .Input(FakeInput()) \ + .Attr("depth_radius", (DEPTH_RADIUS)) \ + .Attr("bias", (BIAS)) \ + .Attr("alpha", ((ALPHA) / 10)) \ + .Attr("beta", (BETA)) \ + .Finalize(node_def())); \ + ASSERT_OK(InitOp()); \ + AddInput(TensorShape({BATCH, 1, 1, DEPTH}), \ + [this](int i) -> float { return RndGaussian(&rand_); }); \ + ASSERT_OK(RunOpKernel()); \ + EXPECT_TRUE(Compare()); \ + } + +// clang-format off +// DEPTH BATCH DEPTH_RADIUS BIAS ALPHA BETA +TCASE(T0, 4, 2, 2, 1.0f, 1.0f, 2.0f) +TCASE(T1, 16, 1, 5, 1.0f, 1.0f, 2.0f) +TCASE(T2, 16, 32, 2, 1.0f, 2.0f, 1.0f) +TCASE(T3, 128, 4, 3, 2.0f, 1.0f, 1.0f) +// clang-format on + +#undef TCASE +} // namespace tensorflow diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc new file mode 100644 index 0000000000..08a4da5b41 --- /dev/null +++ b/tensorflow/core/kernels/matching_files_op.cc @@ -0,0 +1,42 @@ +// See docs in ../ops/io_ops.cc. + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/match.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class MatchingFilesOp : public OpKernel { + public: + using OpKernel::OpKernel; + void Compute(OpKernelContext* context) override { + const Tensor* pattern; + OP_REQUIRES_OK(context, context->input("pattern", &pattern)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(pattern->shape()), + errors::InvalidArgument( + "Input pattern tensor must be scalar, but had shape: ", + pattern->shape().DebugString())); + std::vector fnames; + OP_REQUIRES_OK(context, + io::GetMatchingFiles(context->env(), + pattern->scalar()(), &fnames)); + const int num_out = fnames.size(); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + "filenames", TensorShape({num_out}), &output)); + auto output_vec = output->vec(); + for (int i = 0; i < num_out; ++i) { + output_vec(i) = fnames[i]; + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("MatchingFiles").Device(DEVICE_CPU), + MatchingFilesOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc new file mode 100644 index 0000000000..48bdba78b2 --- /dev/null +++ b/tensorflow/core/kernels/matmul_op.cc @@ -0,0 +1,214 @@ +// See docs in ../ops/math_ops.cc. 
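Stepping back from the LRN tests above before the matmul kernel begins: every one of those tests checks the same closed form, y_d = x_d / (bias + alpha * sum of x_k^2 over the 2R+1 depth neighbours)^beta. Below is a minimal standalone sketch of that reference computation, reproducing the first expected value of the Depth96 test; it is plain C++ for illustration only and is not part of the patch.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Reference LRN along the depth axis, mirroring what lrn_op_test.cc computes:
//   y[d] = x[d] / (bias + alpha * sum_{k in [d-R, d+R]} x[k]^2)^beta
std::vector<float> LrnReference(const std::vector<float>& x, int radius,
                                float bias, float alpha, float beta) {
  const int depth = static_cast<int>(x.size());
  std::vector<float> y(depth);
  for (int d = 0; d < depth; ++d) {
    const int lo = std::max(0, d - radius);
    const int hi = std::min(depth, d + radius + 1);
    float norm = 0.f;
    for (int k = lo; k < hi; ++k) norm += x[k] * x[k];
    y[d] = x[d] / std::pow(bias + alpha * norm, beta);
  }
  return y;
}

int main() {
  std::vector<float> x(96);
  for (int i = 0; i < 96; ++i) x[i] = i + 1;  // same ramp as the Depth96 test
  const auto y = LrnReference(x, /*radius=*/5, /*bias=*/1.0f, /*alpha=*/0.1f,
                              /*beta=*/2.0f);
  // Node 0 only sees x[0..5] = 1..6, so the normalizer is
  // 1 + 0.1 * (1 + 4 + 9 + 16 + 25 + 36) = 10.1 and y[0] = 1 / 10.1^2,
  // exactly the EXPECT_NEAR value in the Depth96 test.
  assert(std::abs(y[0] - 1.0f / (10.1f * 10.1f)) < 1e-4f);
  return 0;
}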
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/matmul_op.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/fill_functor.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/stream_executor/stream.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +#if GOOGLE_CUDA + +namespace { +template +perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory) { + perftools::gputools::DeviceMemoryBase wrapped(const_cast(cuda_memory)); + perftools::gputools::DeviceMemory typed(wrapped); + return typed; +} +} // namespace + +#endif // GOOGLE_CUDA + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct LaunchMatMul; + +// On CPUs, we ignore USE_CUBLAS +template +struct LaunchMatMulCPU { + static void launch( + OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b, + const Eigen::array, 1>& dim_pair, + Tensor* out) { + functor::MatMulFunctor()(ctx->eigen_device(), + out->matrix(), a.matrix(), + b.matrix(), dim_pair); + } +}; + +template +struct LaunchMatMul : public LaunchMatMulCPU {}; + +#if GOOGLE_CUDA + +template +struct LaunchMatMul { + static void launch( + OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b, + const Eigen::array, 1>& dim_pair, + Tensor* out) { + perftools::gputools::blas::Transpose trans[] = { + perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose}; + const uint64 m = a.dim_size(1 - dim_pair[0].first); + const uint64 k = a.dim_size(dim_pair[0].first); + const uint64 n = b.dim_size(1 - dim_pair[0].second); + bool transpose_a = dim_pair[0].first == 0; + bool transpose_b = dim_pair[0].second == 1; + auto blas_transpose_a = trans[transpose_a]; + auto blas_transpose_b = trans[transpose_b]; + + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + + auto a_ptr = AsDeviceMemory(a.template flat().data()); + auto b_ptr = AsDeviceMemory(b.template flat().data()); + auto c_ptr = AsDeviceMemory(out->template flat().data()); + + // Cublas does + // C = A x B + // where A, B and C are assumed to be in column major. + // We want the output to be in row-major, so we can compute + // C' = B' x A' (' stands for transpose) + bool blas_launch_status = + stream->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f, + b_ptr, transpose_b ? k : n, a_ptr, + transpose_a ? 
m : k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal( + "Blas SGEMM launch failed : a.shape=(", a.dim_size(0), ", ", + a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1), + "), m=", m, ", n=", n, ", k=", k)); + } + } +}; + +template +struct LaunchMatMul { + static void launch( + OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b, + const Eigen::array, 1>& dim_pair, + Tensor* out) { + functor::MatMulFunctor()(ctx->eigen_device(), + out->matrix(), a.matrix(), + b.matrix(), dim_pair); + } +}; + +#endif // GOOGLE_CUDA + +template +class MatMulOp : public OpKernel { + public: + explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& a = ctx->input(0); + const Tensor& b = ctx->input(1); + + // Check that the dimensions of the two matrices are valid. + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()), + errors::InvalidArgument("In[0] is not a matrix")); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()), + errors::InvalidArgument("In[1] is not a matrix")); + Eigen::array, 1> dim_pair; + dim_pair[0].first = transpose_a_ ? 0 : 1; + dim_pair[0].second = transpose_b_ ? 1 : 0; + + OP_REQUIRES(ctx, + a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), + errors::InvalidArgument("Matrix size-compatible: In[0]: ", + a.shape().DebugString(), ", In[1]: ", + b.shape().DebugString())); + int a_dim_remaining = 1 - dim_pair[0].first; + int b_dim_remaining = 1 - dim_pair[0].second; + TensorShape out_shape( + {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)}); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); + + if (out->NumElements() == 0) { + // If a has shape [0, x] or b has shape [x, 0], the output shape + // is a 0-element matrix, so there is nothing to do. + return; + } + + if (a.NumElements() == 0 || b.NumElements() == 0) { + // If a has shape [x, 0] and b has shape [0, y], the + // output shape is [x, y] where x and y are non-zero, so we fill + // the output with zeros. + functor::SetZeroFunctor f; + f(ctx->eigen_device(), out->flat()); + return; + } + + LaunchMatMul::launch(ctx, this, a, b, dim_pair, out); + } + + private: + bool transpose_a_; + bool transpose_b_; +}; + +namespace functor { + +// Partial specialization MatMulFunctor. 
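The GPU launch above leans on the identity (A*B)^T = B^T * A^T: handing cuBLAS the operands in swapped order makes its column-major result land in memory exactly as the row-major C that the op wants. The sketch below works that argument through on the CPU with a naive column-major GEMM; it is plain illustrative C++, not TensorFlow or cuBLAS code. The CPU partial specialization of MatMulFunctor continues right after it.

#include <cassert>
#include <vector>

// Plain column-major GEMM: C = A * B, where A is m x k, B is k x n and all
// three buffers are stored column-major (the convention cuBLAS expects).
void GemmColMajor(int m, int n, int k, const float* a, const float* b,
                  float* c) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += a[i + p * m] * b[p + j * k];
      c[i + j * m] = acc;
    }
}

int main() {
  // Row-major inputs: A is 2x3, B is 3x2, so C = A * B is 2x2 (row-major).
  const int m = 2, k = 3, n = 2;
  std::vector<float> a = {1, 2, 3, 4, 5, 6};     // A, row-major
  std::vector<float> b = {7, 8, 9, 10, 11, 12};  // B, row-major
  std::vector<float> c(m * n);
  // A row-major m x k buffer is byte-for-byte a column-major k x m buffer
  // holding A^T. So asking the column-major routine for B^T * A^T (an n x m
  // result) leaves C^T in column-major order, i.e. C in row-major order.
  GemmColMajor(n, m, k, b.data(), a.data(), c.data());
  // Reference: C = [[58, 64], [139, 154]] read row-major.
  assert(c[0] == 58 && c[1] == 64 && c[2] == 139 && c[3] == 154);
  return 0;
}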
+template +struct MatMulFunctor { + void operator()( + const CPUDevice& d, typename MatMulTypes::out_type out, + typename MatMulTypes::in_type in0, + typename MatMulTypes::in_type in1, + const Eigen::array, 1>& dim_pair) { + MatMul(d, out, in0, in1, dim_pair); + } +}; + +} // end namespace functor + +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ + MatMulOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_CPU).TypeConstraint("T").Label("eigen"), \ + MatMulOp) + +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_GPU).TypeConstraint("T"), \ + MatMulOp); \ + REGISTER_KERNEL_BUILDER(Name("MatMul") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .Label("cublas"), \ + MatMulOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_GPU).TypeConstraint("T").Label("eigen"), \ + MatMulOp) + +REGISTER_CPU(float); +REGISTER_CPU(double); +REGISTER_CPU(int32); +REGISTER_CPU(complex64); +#if GOOGLE_CUDA +REGISTER_GPU(float); +// REGISTER_GPU(double); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h new file mode 100644 index 0000000000..f75b0ded1b --- /dev/null +++ b/tensorflow/core/kernels/matmul_op.h @@ -0,0 +1,40 @@ +#ifndef TENSORFLOW_KERNELS_MATMUL_OP_H_ +#define TENSORFLOW_KERNELS_MATMUL_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Helpers to define tensor needed by MatMul op. +template +struct MatMulTypes { + typedef Eigen::TensorMap, Eigen::Aligned> + out_type; + typedef Eigen::TensorMap, + Eigen::Aligned> in_type; +}; + +template +void MatMul(const Device& d, Out out, In0 in0, In1 in1, + const DimPair& dim_pair) { + out.device(d) = in0.contract(in1, dim_pair); +} + +template +struct MatMulFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, typename MatMulTypes::out_type out, + typename MatMulTypes::in_type in0, + typename MatMulTypes::in_type in1, + const Eigen::array, 1>& dim_pair); +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_ diff --git a/tensorflow/core/kernels/matmul_op_gpu.cu.cc b/tensorflow/core/kernels/matmul_op_gpu.cu.cc new file mode 100644 index 0000000000..17107ce5df --- /dev/null +++ b/tensorflow/core/kernels/matmul_op_gpu.cu.cc @@ -0,0 +1,32 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/matmul_op.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +// Partial specialization MatMulTensorFunctor +template +struct MatMulFunctor { + void operator()( + const GPUDevice& d, typename MatMulTypes::out_type out, + typename MatMulTypes::in_type in0, + typename MatMulTypes::in_type in1, + const Eigen::array, 1>& dim_pair) { + MatMul(d, To32Bit(out), To32Bit(in0), To32Bit(in1), dim_pair); + } +}; + +#define DEFINE(T) template struct MatMulFunctor; +DEFINE(float); +// DEFINE(double); // Does not compile 1/2015. 
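Both the Eigen-labelled CPU and GPU paths above end up in the MatMul helper declared in matmul_op.h, which writes the product as a rank-1 tensor contraction over dim_pair. Here is a standalone sketch of that call outside the functor machinery; it assumes only Eigen's unsupported Tensor module, and the values are illustrative.

#include <iostream>

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> a(2, 3);
  Eigen::Tensor<float, 2> b(3, 2);
  a.setValues({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
  b.setValues({{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}});
  // dim_pair as MatMulOp builds it for transpose_a = transpose_b = false:
  // contract dimension 1 of the left operand against dimension 0 of the
  // right operand, which is exactly out = a * b.
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
      Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
  Eigen::Tensor<float, 2> out = a.contract(b, dim_pair);
  std::cout << out << "\n";  // 2 x 2 result with entries 58, 64, 139, 154
  return 0;
}

Flipping either index of the pair (as the transpose_a / transpose_b attrs do in MatMulOp::Compute) contracts a different axis and reproduces the transposed variants without materializing a transpose.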
+#undef DEFINE + +} // end namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc new file mode 100644 index 0000000000..b2b8f3d905 --- /dev/null +++ b/tensorflow/core/kernels/matmul_op_test.cc @@ -0,0 +1,56 @@ +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { + +static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor in0(DT_FLOAT, transpose_a ? TensorShape({k, m}) : TensorShape({m, k})); + in0.flat().setRandom(); + Tensor in1(DT_FLOAT, transpose_b ? TensorShape({n, k}) : TensorShape({k, n})); + in1.flat().setRandom(); + test::graph::Matmul(g, test::graph::Constant(g, in0), + test::graph::Constant(g, in1), transpose_a, transpose_b); + return g; +} + +#define BM_MatmulDev(M, K, N, TA, TB, DEVICE) \ + static void BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE( \ + int iters) { \ + testing::ItemsProcessed(static_cast(iters) * M * K * N * 2); \ + test::Benchmark(#DEVICE, Matmul(M, K, N, TA, TB)).Run(iters); \ + } \ + BENCHMARK(BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE); + +#define BM_Matmul(M, K, N, TA, TB) \ + BM_MatmulDev(M, K, N, TA, TB, cpu); \ + BM_MatmulDev(M, K, N, TA, TB, gpu); + +// Typical fully connected layers +BM_Matmul(8, 512, 512, false, false); +BM_Matmul(16, 512, 512, false, false); +BM_Matmul(128, 512, 512, false, false); + +BM_Matmul(8, 1024, 1024, false, false); +BM_Matmul(16, 1024, 1024, false, false); +BM_Matmul(128, 1024, 1024, false, false); +BM_Matmul(4096, 4096, 4096, false, false); + +// Backward for fully connected layers +BM_Matmul(8, 1024, 1024, false, true); +BM_Matmul(16, 1024, 1024, false, true); +BM_Matmul(128, 1024, 1024, false, true); + +// Forward softmax with large output size +BM_Matmul(8, 200, 10000, false, false); +BM_Matmul(20, 200, 10000, false, false); +BM_Matmul(20, 200, 20000, false, false); + +// Backward softmax with large output size +BM_Matmul(8, 10000, 200, false, true); +BM_Matmul(20, 10000, 200, false, true); +BM_Matmul(20, 20000, 200, false, true); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_inverse_op.cc b/tensorflow/core/kernels/matrix_inverse_op.cc new file mode 100644 index 0000000000..ad0948d6ef --- /dev/null +++ b/tensorflow/core/kernels/matrix_inverse_op.cc @@ -0,0 +1,64 @@ +// See docs in ../ops/linalg_ops.cc. 
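One detail of the benchmarks above worth spelling out: testing::ItemsProcessed is given iters * M * K * N * 2 because a dense M-by-K times K-by-N product does K multiplies and K additions for each of the M*N outputs. A trivial standalone check of that count follows; it is plain C++ for illustration only, and the matrix inverse kernel resumes immediately after.

#include <cassert>
#include <cstdint>
#include <vector>

// Count the scalar operations of a plain triple-loop matmul to make the
// 2 * M * K * N figure explicit.
int64_t CountMatmulOps(int m, int k, int n) {
  int64_t ops = 0;
  std::vector<float> a(m * k, 1.f), b(k * n, 1.f), c(m * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p) {
        c[i * n + j] += a[i * k + p] * b[p * n + j];
        ops += 2;  // one multiply, one add
      }
  return ops;
}

int main() {
  // Same shape as the first "typical fully connected layer" benchmark.
  assert(CountMatmulOps(8, 512, 512) == 2LL * 8 * 512 * 512);
  return 0;
}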
+#include <cmath>
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/Eigen/LU"
+
+namespace tensorflow {
+
+template <class Scalar, bool SupportsBatchOperationT>
+class MatrixInverseOp
+    : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
+ public:
+  explicit MatrixInverseOp(OpKernelConstruction* context)
+      : LinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
+  ~MatrixInverseOp() override {}
+
+  TensorShape GetOutputMatrixShape(
+      const TensorShape& input_matrix_shape) override {
+    return input_matrix_shape;
+  }
+
+  int64 GetCostPerUnit(const TensorShape& input_matrix_shape) override {
+    const int64 rows = input_matrix_shape.dim_size(0);
+    if (rows > (1LL << 20)) {
+      // A big number to cap the cost in case of overflow.
+      return kint32max;
+    } else {
+      return rows * rows * rows;
+    }
+  }
+
+  using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
+  using typename LinearAlgebraOp<Scalar,
+                                 SupportsBatchOperationT>::ConstMatrixMap;
+
+  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
+                     MatrixMap* output) override {
+    OP_REQUIRES(context, input.rows() == input.cols(),
+                errors::InvalidArgument("Input matrix must be square."));
+    if (input.rows() == 0) {
+      // By definition, an empty matrix's inverse is an empty matrix.
+      return;
+    }
+    Eigen::FullPivLU<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>
+        lu_decomposition(input);
+    OP_REQUIRES(context, lu_decomposition.isInvertible(),
+                errors::InvalidArgument("Input is not invertible."));
+    *output = lu_decomposition.inverse();
+  }
+};
+
+REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<float, false>), float);
+REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<double, false>), double);
+REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<float, true>), float);
+REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<double, true>),
+                   double);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
new file mode 100644
index 0000000000..31046018c5
--- /dev/null
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -0,0 +1,554 @@
+// See docs in ../ops/nn_ops.cc.
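MatrixInverseOp above is a thin wrapper: all of the numerical work is Eigen's FullPivLU. Before the pooling kernels that follow, here is a standalone sketch of the same square-check / invertibility-check / inverse flow outside the LinearAlgebraOp machinery; it assumes only Eigen and is illustrative, not the kernel itself.

#include <cassert>
#include <iostream>

#include <Eigen/Core>
#include <Eigen/LU>  // Eigen::FullPivLU

int main() {
  Eigen::MatrixXf input(2, 2);
  input << 4, 7,
           2, 6;
  assert(input.rows() == input.cols());  // the kernel rejects non-square input
  Eigen::FullPivLU<Eigen::MatrixXf> lu(input);
  if (!lu.isInvertible()) {  // mirrors the InvalidArgument check above
    std::cerr << "Input is not invertible.\n";
    return 1;
  }
  const Eigen::MatrixXf inv = lu.inverse();
  // det([[4, 7], [2, 6]]) = 10, so the inverse is [[0.6, -0.7], [-0.2, 0.4]];
  // verify that input * inv is the identity to within float tolerance.
  assert((input * inv - Eigen::MatrixXf::Identity(2, 2)).norm() < 1e-5f);
  return 0;
}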
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/maxpooling_op.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/pooling_ops_common.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/util/use_cudnn.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/core/kernels/maxpooling_op_gpu.h" +#include "tensorflow/core/kernels/pooling_ops_common_gpu.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +const int kInvalidMaxPoolingIndex = -1; + +template +struct SpatialMaxPoolWithArgMaxHelper { + static void Compute(Tensor* output, Tensor* output_arg_max, + const Tensor& tensor_in, const PoolParameters& params, + const Padding& padding) { + typedef Eigen::Map> + ConstEigenMatrixMap; + typedef Eigen::Map> + EigenMatrixMap; + typedef Eigen::Map> + EigenIndexMatrixMap; + + ConstEigenMatrixMap in_mat( + tensor_in.flat().data(), params.depth, + params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch); + EigenMatrixMap out_mat( + output->flat().data(), params.depth, + params.out_width * params.out_height * params.tensor_in_batch); + EigenIndexMatrixMap out_arg_max_mat( + output_arg_max->flat().data(), params.depth, + params.out_width * params.out_height * params.tensor_in_batch); + + // Initializes the output tensor with MIN. + output_arg_max->flat().setConstant(kInvalidMaxPoolingIndex); + output->flat().setConstant(Eigen::NumTraits::lowest()); + + // The following code basically does the following: + // 1. Flattens the input and output tensors into two dimensional arrays. + // tensor_in_as_matrix: + // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch) + // output_as_matrix: + // depth by (out_width * out_height * tensor_in_batch) + // + // 2. Walks through the set of columns in the flattened tensor_in_as_matrix, + // and updates the corresponding column(s) in output_as_matrix with the + // max value. + for (int b = 0; b < params.tensor_in_batch; ++b) { + for (int h = 0; h < params.tensor_in_rows; ++h) { + for (int w = 0; w < params.tensor_in_cols; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int hpad = h + params.pad_rows; + const int wpad = w + params.pad_cols; + const int h_start = + (hpad < params.window_rows) + ? 0 + : (hpad - params.window_rows) / params.row_stride + 1; + const int h_end = + std::min(hpad / params.row_stride + 1, params.out_height); + const int w_start = + (wpad < params.window_cols) + ? 
0 + : (wpad - params.window_cols) / params.col_stride + 1; + const int w_end = + std::min(wpad / params.col_stride + 1, params.out_width); + // compute elementwise max + const int in_index = + (b * params.tensor_in_rows + h) * params.tensor_in_cols + w; + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + const int out_index = + (b * params.out_height + ph) * params.out_width + pw; + /// NOTES(zhengxq): not using the eigen matrix operation for now. + /// May consider parallelizing the operations if needed. + for (int d = 0; d < params.depth; ++d) { + const T& input_ref = in_mat.coeffRef(d, in_index); + T& output_ref = out_mat.coeffRef(d, out_index); + int64& out_arg_max_ref = out_arg_max_mat.coeffRef(d, out_index); + if (output_ref < input_ref || + out_arg_max_ref == kInvalidMaxPoolingIndex) { + output_ref = input_ref; + int input_offset = in_index * params.depth + d; + out_arg_max_ref = input_offset; + } + } + } + } + } + } + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_CPU), + MaxPoolingOp); + +#if GOOGLE_CUDA +// Forward declarations for the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void SpatialMaxPooling::operator()( \ + const Eigen::GpuDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, int window_rows, \ + int window_cols, int row_stride, int col_stride, \ + const Eigen::PaddingType& padding); \ + extern template struct SpatialMaxPooling; + +DECLARE_GPU_SPEC(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Note(jiayq): Currently, the Caffe custom implementation is faster than the +// default Eigen implementation so we are using the custom kernel as the +// default. However, you can explicitly invoke the eigen version using +// kernel_label_map. +REGISTER_KERNEL_BUILDER(Name("MaxPool") + .Device(DEVICE_GPU) + .Label("eigen_tensor"), + MaxPoolingOp); +#endif // GOOGLE_CUDA + +// The operation to compute MaxPool gradients. +// It takes three inputs: +// - The original input tensor +// - The original output tensor +// - Backprop tensor for output +// It produces one output: backprop tensor for input. +template +class MaxPoolingGradOp : public OpKernel { + public: + explicit MaxPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument( + "Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument( + "Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + OP_REQUIRES( + context, ksize_[3] == 1 && stride_[3] == 1, + errors::Unimplemented( + "MaxPoolingGrad is not yet supported on the depth dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + const Tensor& tensor_out = context->input(1); + const Tensor& out_backprop = context->input(2); + + // For maxpooling, tensor_in should have 4 dimensions. 
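(The gradient op's shape checks continue below.) The index bookkeeping in SpatialMaxPoolWithArgMaxHelper above inverts the usual pooling map: for an input row h it computes the half-open range [h_start, h_end) of output rows whose window covers h. A scalar sketch of that arithmetic follows; the window, stride, and padding values in it are purely illustrative.

#include <algorithm>
#include <cassert>

// For input coordinate h, return [h_start, h_end) = the output rows whose
// pooling window covers h, using the same formula as the helper above.
void OutputRangeForInput(int h, int window, int stride, int pad, int out_size,
                         int* h_start, int* h_end) {
  const int hpad = h + pad;
  *h_start = (hpad < window) ? 0 : (hpad - window) / stride + 1;
  *h_end = std::min(hpad / stride + 1, out_size);
}

int main() {
  // Hypothetical sizes: window 3, stride 2, no padding, 4 output rows.
  int h_start, h_end;
  OutputRangeForInput(/*h=*/4, /*window=*/3, /*stride=*/2, /*pad=*/0,
                      /*out_size=*/4, &h_start, &h_end);
  // Output row ph covers inputs [2*ph, 2*ph + 3), so input 4 is covered by
  // ph = 1 (rows 2..4) and ph = 2 (rows 4..6): the range is [1, 3).
  assert(h_start == 1 && h_end == 3);
  return 0;
}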
+ OP_REQUIRES(context, tensor_in.dims() == 4, + errors::InvalidArgument("tensor_in must be 4-dimensional")); + OP_REQUIRES(context, tensor_out.dims() == 4, + errors::InvalidArgument("tensor_out must be 4-dimensional")); + // For maxpooling, out_backprop should have 4 dimensions. + OP_REQUIRES(context, out_backprop.dims() == 4, + errors::InvalidArgument("out_backprop must be 4-dimensional")); + + TensorShape output_shape = tensor_in.shape(); + + // Tensor index_tensor(context->allocator(), DT_INT32, output_shape); + + Tensor tensor_out_dup; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::v(), + tensor_out.shape(), &tensor_out_dup)); + Tensor tensor_out_arg_max; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), + tensor_out.shape(), + &tensor_out_arg_max)); + + PoolParameters params{context, ksize_, stride_, padding_, + tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + output->flat().setZero(); + + SpatialMaxPoolWithArgMaxHelper::Compute( + &tensor_out_dup, &tensor_out_arg_max, tensor_in, params, padding_); + auto out_backprop_flat = out_backprop.flat(); + auto input_backprop_flat = output->flat(); + auto out_arg_max_flat = tensor_out_arg_max.flat(); + int num_total_outputs = out_backprop.flat().size(); + int num_total_inputs = input_backprop_flat.size(); + + for (int index = 0; index < num_total_outputs; ++index) { + int input_backprop_index = out_arg_max_flat(index); + // Although this check is in the inner loop, it is worth its value + // so we don't end up with memory corruptions. Our benchmark shows that + // the performance impact is quite small + CHECK(input_backprop_index >= 0 && + input_backprop_index < num_total_inputs) + << "Invalid input backprop index: " << input_backprop_index << ", " + << num_total_inputs; + input_backprop_flat(input_backprop_index) += out_backprop_flat(index); + } + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_CPU), + MaxPoolingGradOp); + +#ifdef GOOGLE_CUDA + +static void MaxPoolingBackwardCustomKernel( + OpKernelContext* context, const std::vector& size, + const std::vector& stride, Padding padding, const Tensor* tensor_in, + const Tensor& out_backprop, const TensorShape& tensor_in_shape) { + Tensor* output = nullptr; + + OP_REQUIRES_OK(context, + context->allocate_output(0, tensor_in_shape, &output)); + + PoolParameters params{context, size, stride, padding, tensor_in_shape}; + if (!context->status().ok()) { + return; + } + + MaxPoolBackwardNoMask( + tensor_in->flat().data(), params.tensor_in_batch, + params.tensor_in_rows, params.tensor_in_cols, params.depth, + params.out_height, params.out_width, params.window_rows, + params.window_cols, params.row_stride, params.col_stride, params.pad_rows, + params.pad_cols, out_backprop.flat().data(), + output->flat().data(), context->eigen_device()); +} + +template +class MaxPoolingGradOp : public OpKernel { + public: + typedef Eigen::GpuDevice Device; + + explicit MaxPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument( + "Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + 
errors::InvalidArgument( + "Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + + use_dnn_ = CanUseCudnn(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + const Tensor& tensor_out = context->input(1); + const Tensor& out_backprop = context->input(2); + + // For maxpooling, tensor_in should have 4 dimensions. + OP_REQUIRES(context, tensor_in.dims() == 4, + errors::InvalidArgument("tensor_in must be 4-dimensional 4")); + OP_REQUIRES(context, tensor_out.dims() == 4, + errors::InvalidArgument("tensor_out must be 4-dimensional")); + // For maxpooling, out_backprop should have 4 dimensions. + OP_REQUIRES(context, out_backprop.dims() == 4, + errors::InvalidArgument("out_backprop must be 4-dimensional")); + + TensorShape output_shape = tensor_in.shape(); + + if (use_dnn_) { + DnnPoolingGradOp::Compute( + context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_, + stride_, padding_, &tensor_in, &tensor_out, out_backprop, + output_shape); + } else { + MaxPoolingBackwardCustomKernel(context, ksize_, stride_, padding_, + &tensor_in, out_backprop, output_shape); + } + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; + bool use_dnn_; +}; + +REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_GPU), + MaxPoolingGradOp); + +#endif // GOOGLE_CUDA + +template +struct LaunchMaxPoolingNoMask; + +template +class MaxPoolingNoMaskOp : public OpKernel { + public: + explicit MaxPoolingNoMaskOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + + PoolParameters params{context, ksize_, stride_, padding_, + tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + TensorShape out_shape({params.tensor_in_batch, params.out_height, + params.out_width, params.depth}); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + LaunchMaxPoolingNoMask::launch(context, params, tensor_in, + output); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +template +struct LaunchMaxPoolingWithArgmax; + +template +class MaxPoolingWithArgmaxOp : public OpKernel { + public: + explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument( + "Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument( + "Sliding window stride 
field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + + PoolParameters params{context, ksize_, stride_, padding_, + tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + TensorShape out_shape({params.tensor_in_batch, params.out_height, + params.out_width, params.depth}); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + Tensor* argmax = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax)); + + LaunchMaxPoolingWithArgmax::launch(context, params, tensor_in, + output, argmax); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +template +struct LaunchMaxPoolingGradWithArgmax; + +template +class MaxPoolingGradWithArgmaxOp : public OpKernel { + public: + explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument( + "Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument( + "Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + const Tensor& grad_in = context->input(1); + const Tensor& argmax = context->input(2); + + PoolParameters params{context, ksize_, stride_, padding_, + tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + TensorShape out_shape({params.tensor_in_batch, params.tensor_in_rows, + params.tensor_in_cols, params.depth}); + Tensor* grad_out = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &grad_out)); + + LaunchMaxPoolingGradWithArgmax::launch(context, params, grad_in, + argmax, grad_out); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +#if GOOGLE_CUDA + +template +struct LaunchMaxPoolingNoMask { + static void launch(OpKernelContext* context, const PoolParameters& params, + const Tensor& input, Tensor* output) { + bool status = MaxPoolForwardWithOptionalArgmax( + input.flat().data(), params.tensor_in_batch, params.tensor_in_rows, + params.tensor_in_cols, params.depth, params.out_height, + params.out_width, params.window_rows, params.window_cols, + params.row_stride, params.col_stride, params.pad_rows, params.pad_cols, + output->flat().data(), nullptr, context->eigen_gpu_device()); + if (!status) { + context->SetStatus( + errors::Internal("Failed launching MaxPoolForwardNoMask")); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_GPU), + MaxPoolingNoMaskOp); + +template +struct LaunchMaxPoolingWithArgmax { + static void launch(OpKernelContext* context, const PoolParameters& params, + const Tensor& input, Tensor* output, Tensor* argmax) { + bool status = MaxPoolForwardWithOptionalArgmax( + input.flat().data(), 
params.tensor_in_batch, params.tensor_in_rows, + params.tensor_in_cols, params.depth, params.out_height, + params.out_width, params.window_rows, params.window_cols, + params.row_stride, params.col_stride, params.pad_rows, params.pad_cols, + output->flat().data(), + reinterpret_cast(argmax->flat().data()), + context->eigen_gpu_device()); + if (!status) { + context->SetStatus( + errors::Internal("Failed launching MaxPoolForwardWithArgmax")); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") + .Device(DEVICE_GPU) + .TypeConstraint("Targmax"), + MaxPoolingWithArgmaxOp); + +template +struct LaunchMaxPoolingGradWithArgmax { + static void launch(OpKernelContext* context, const PoolParameters& params, + const Tensor& grad_in, const Tensor& argmax, + Tensor* grad_out) { + const int input_size = params.tensor_in_batch * params.tensor_in_rows * + params.tensor_in_cols * params.depth; + const int output_size = params.tensor_in_batch * params.out_height * + params.out_width * params.depth; + const int top_offset = params.out_height * params.out_width * params.depth; + const int bottom_offset = + params.tensor_in_rows * params.tensor_in_cols * params.depth; + bool status = MaxPoolBackwardWithArgmax( + output_size, input_size, grad_in.flat().data(), + reinterpret_cast(argmax.flat().data()), top_offset, + bottom_offset, grad_out->flat().data(), context->eigen_gpu_device()); + if (!status) { + context->SetStatus( + errors::Internal("Failed launching MaxPoolForwardWithArgmax")); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") + .Device(DEVICE_GPU) + .TypeConstraint("Targmax"), + MaxPoolingGradWithArgmaxOp); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/maxpooling_op.h b/tensorflow/core/kernels/maxpooling_op.h new file mode 100644 index 0000000000..a074174118 --- /dev/null +++ b/tensorflow/core/kernels/maxpooling_op.h @@ -0,0 +1,29 @@ +#ifndef TENSORFLOW_KERNELS_MAXPOOLING_OP_H_ +#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_ +// Functor definition for MaxPoolingOp, must be compilable by nvcc. + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" + +namespace tensorflow { +namespace functor { + +template +struct SpatialMaxPooling { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, int window_rows, + int window_cols, int row_stride, int col_stride, + const Eigen::PaddingType& padding) { + // Because we swap the layout, we swap the row/cols as well + output.swap_layout().device(d) = + Eigen::SpatialMaxPooling(input.swap_layout(), window_cols, window_rows, + col_stride, row_stride, padding); + } +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_MAXPOOLING_OP_H_ diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc new file mode 100644 index 0000000000..65262eb54e --- /dev/null +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -0,0 +1,261 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/maxpooling_op.h" +#include "tensorflow/core/kernels/maxpooling_op_gpu.h" + +namespace tensorflow { +namespace { +// This is Yangqing's custom kernel for the maxpooling operation. 
There are +// three functions: MaxPoolForwardNCHW and MaxPoolForwardNHWC are the two +// forward functions, dealing with the forward case. MaxPoolBackward is the +// backward function that deals with the backward case for both storage orders. +// The parameters to the kernels in the forward function is as follows: +// nthreads: the number of threads, which is equal to the output size. +// bottom_data: the bottom data of N*H*W*C (or N*C*H*W) items. +// height, width, pooled_height, pooled_width: the input and output sizes. +// kernel_h, kernel_w: the kernel sizes. +// stride_h, stride_w: the strides. +// pad_t, pad_l: the padding values on the top and left side. +// top_data: the maxpool output. +// mask: the output mask of the same size as top_data. It is stored in +// int form, keeping track of the flattened index of the input item that +// produces the max output. If a nullptr is passed in for mask, no mask +// will be produced. +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); i += blockDim.x * gridDim.x) + +// To call the forward and backward functions, use e.g.: +// const int kThreadsPerBlock = 1024 +// const int output_size = batch * channels * pooled_height * pooled_width; +// MaxPoolForwardNCHW<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, +// kThreadsPerBlock, 0, cuda_stream>>>(...); +template +__global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data, + const int channels, const int height, + const int width, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, + const int pad_l, dtype* top_data, + int64* mask) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dtype maxval = -FLT_MAX; + int maxidx = -1; + const dtype* bottom_data_n = bottom_data + n * channels * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int idx = c * height * width + h * width + w; + if (bottom_data_n[idx] > maxval) { + maxidx = idx; + maxval = bottom_data_n[idx]; + } + } + } + top_data[index] = maxval; + if (mask != nullptr) { + mask[index] = maxidx; + } + } +} + +template +__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, + const int pad_l, dtype* top_data, + int64* mask) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int n = index; + int c = n % channels; + n /= channels; + int wstart = (n % pooled_width) * stride_w - pad_l; + n /= pooled_width; + int hstart = (n % pooled_height) * stride_h - pad_t; + n /= pooled_height; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dtype maxval = -FLT_MAX; + int maxidx = -1; + const dtype* bottom_data_n = bottom_data + n * height * width * channels; + for (int h = hstart; h < hend; ++h) { + for 
(int w = wstart; w < wend; ++w) { + int idx = (h * width + w) * channels + c; + if (bottom_data_n[idx] > maxval) { + maxidx = idx; + maxval = bottom_data_n[idx]; + } + } + } + top_data[index] = maxval; + if (mask != nullptr) { + mask[index] = maxidx; + } + } +} + +template +__global__ void MaxPoolBackwardNoMaskNHWC( + const int nthreads, const dtype* bottom_data, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + const dtype* top_diff, dtype* bottom_diff) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // First find out the index to the maximum, since we have no mask. + int n = index; + int c = n % channels; + n /= channels; + int wstart = (n % pooled_width) * stride_w - pad_l; + n /= pooled_width; + int hstart = (n % pooled_height) * stride_h - pad_t; + n /= pooled_height; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dtype maxval = -FLT_MAX; + int maxidx = -1; + const dtype* bottom_data_n = bottom_data + n * height * width * channels; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int idx = (h * width + w) * channels + c; + if (bottom_data_n[idx] > maxval) { + maxidx = idx; + maxval = bottom_data_n[idx]; + } + } + } + + // Atomically accumulate the bottom diff. The index could still be + // uninitialized, if all the bottom_data are NaN. + if (maxidx != -1) { + atomicAdd(bottom_diff + n * height * width * channels + maxidx, + top_diff[index]); + } + } +} + +// The parameters to the kernels in the backward function is as follows: +// nthreads: the number of threads, which is equal to the output size. +// top_diff: the gradient of the output data, of size N*Hout*Wout*C (or +// N*C*Hout*Wout). As we have stored the flattened index of the input +// entries, the backward function is agnostic of the input storage order. +// mask: the output mask of the same size as top_data. It is stored in +// int form, keeping track of the flattened index of the input item that +// produces the max output. +// top_offset: the pre-computed per-image offset of the maxpool output. This +// is equal to Hout*Wout*C. We choose to pre-compute this so we do not +// need to compute it every time inside the kernel. +// bottom_offset: the pre-computed per-image offset of the maxpool input. +// This is equal to H*W*C. +// bottom_diff: the gradient with respect to the input. +// This function relies on atomicAdd to avoid race conditions. Also, before the +// kernel is run, you will need to make sure that bottom_diff is filled with +// zero first. 
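The contract spelled out in the comment above is compact enough to restate in plain C++: every output gradient is added to the input element whose flattened per-image index is stored in mask, with top_offset and bottom_offset supplying the per-image strides. The serial sketch below performs that scatter; it is illustrative only (int64 is written as long long here), and the CUDA kernel that follows does the same thing with one thread per output element and atomicAdd.

#include <vector>

// Serial equivalent of the MaxPoolBackward scatter described above.
void MaxPoolBackwardCpu(int num_outputs, const float* top_diff,
                        const long long* mask, int top_offset,
                        int bottom_offset, float* bottom_diff) {
  for (int index = 0; index < num_outputs; ++index) {
    const int image_id = index / top_offset;
    bottom_diff[image_id * bottom_offset + mask[index]] += top_diff[index];
  }
}

int main() {
  // Hypothetical single-image example: 2 pooled outputs over 4 inputs.
  const float top_diff[] = {0.5f, 1.5f};
  const long long mask[] = {1, 1};          // both windows picked input 1
  std::vector<float> bottom_diff(4, 0.f);   // must be zero-filled first
  MaxPoolBackwardCpu(/*num_outputs=*/2, top_diff, mask, /*top_offset=*/2,
                     /*bottom_offset=*/4, bottom_diff.data());
  // bottom_diff is now {0, 2.0, 0, 0}: gradients accumulate at the argmax.
  return 0;
}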
+template +__global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff, + const int64* mask, const int top_offset, + const int bottom_offset, dtype* bottom_diff) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int image_id = (index / top_offset); + atomicAdd(bottom_diff + image_id * bottom_offset + mask[index], + top_diff[index]); + } +} + +template +__global__ void SetZero(const int nthreads, dtype* bottom_diff) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = dtype(0); } +} + +#undef CUDA_1D_KERNEL_LOOP +} // namespace + +bool MaxPoolForwardWithOptionalArgmax( + const float* bottom_data, const int batch, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + float* top_data, int64* mask, const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + const int output_size = batch * channels * pooled_height * pooled_width; + + MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + output_size, bottom_data, height, width, channels, pooled_height, + pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, + top_data, mask); + return d.ok(); +} + +bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, + const float* top_diff, float* bottom_diff, + const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + const int bottom_size = batch * channels * height * width; + const int top_size = batch * channels * pooled_height * pooled_width; + + SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff); + + MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) / + kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + top_size, bottom_data, height, width, channels, pooled_height, + pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, + top_diff, bottom_diff); + return d.ok(); +} + +bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, + const float* top_diff, const int64* mask, + const int top_offset, const int bottom_offset, + float* bottom_diff, const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff); + MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff); + return d.ok(); +} + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::SpatialMaxPooling; + +DEFINE_GPU_KERNELS(float) + +#undef DEFINE_GPU_KERNELS + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h new file mode 100644 index 0000000000..bfdac904cc --- /dev/null +++ b/tensorflow/core/kernels/maxpooling_op_gpu.h @@ -0,0 +1,42 @@ +#if !GOOGLE_CUDA +#error This file must only be included when building with Cuda support +#endif + +#ifndef 
THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_ + +#define EIGEN_USE_GPU + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" + +namespace tensorflow { + +// Run the forward pass of max pooling, optionally writing the argmax indices to +// the mask array, if it is not nullptr. If mask is passed in as nullptr, the +// argmax indices are not written. +bool MaxPoolForwardWithOptionalArgmax( + const float* bottom_data, const int batch, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + float* top_data, int64* mask, const Eigen::GpuDevice& d); + +bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, + const float* top_diff, const int64* mask, + const int top_offset, const int bottom_offset, + float* bottom_diff, const Eigen::GpuDevice& d); + +bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, + const float* top_diff, float* bottom_diff, + const Eigen::GpuDevice& d); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_ diff --git a/tensorflow/core/kernels/no_op.cc b/tensorflow/core/kernels/no_op.cc new file mode 100644 index 0000000000..b4f9df81a6 --- /dev/null +++ b/tensorflow/core/kernels/no_op.cc @@ -0,0 +1,8 @@ +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_CPU), NoOp); +REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_GPU), NoOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/no_op.h b/tensorflow/core/kernels/no_op.h new file mode 100644 index 0000000000..a3bcbd7680 --- /dev/null +++ b/tensorflow/core/kernels/no_op.h @@ -0,0 +1,17 @@ +#ifndef TENSORFLOW_KERNELS_NO_OP_H_ +#define TENSORFLOW_KERNELS_NO_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class NoOp : public OpKernel { + public: + explicit NoOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} + bool IsExpensive() override { return false; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_NO_OP_H_ diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc new file mode 100644 index 0000000000..7bea17b9e2 --- /dev/null +++ b/tensorflow/core/kernels/ops_testutil.cc @@ -0,0 +1,18 @@ +#include "tensorflow/core/kernels/ops_testutil.h" + +namespace tensorflow { +namespace test { + +NodeDef Node(const string& name, const string& op, + const std::vector& inputs) { + NodeDef def; + def.set_name(name); + def.set_op(op); + for (const string& s : inputs) { + def.add_input(s); + } + return def; +} + +} // namespace test +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h new file mode 100644 index 0000000000..7a3405bf04 --- /dev/null +++ b/tensorflow/core/kernels/ops_testutil.h @@ -0,0 +1,191 @@ +#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ +#define 
TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include + +namespace tensorflow { + +namespace test { + +// Return a NodeDef with the specified name/op/inputs. +NodeDef Node(const string& name, const string& op, + const std::vector& inputs); + +} // namespace test + +// Helpful functions to test operators. +// +// This class will eventually be replaced / heavily modified +// to use the BrainClient interface. +class OpsTestBase : public ::testing::Test { + public: + OpsTestBase() : device_type_(DEVICE_CPU) { + device_.reset( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + CHECK(device_.get()) << "Could not create CPU device"; + } + + ~OpsTestBase() override { + gtl::STLDeleteElements(&tensors_); + context_.reset(nullptr); + } + + void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); } + + // Clients can manipulate the underlying NodeDef via this accessor. + NodeDef* node_def() { return &node_def_; } + + // Initializes an operator that takes in 'input_types' as input + // and output types as output. + // + // Returns the status of initialization. + Status InitOp() { + Status status; + kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(), + node_def_, &status); + if (kernel_ != nullptr) input_types_ = kernel_->input_types(); + return status; + } + + // Adds an input for every element described by the shape. + // 'input_mapping' maps an index (0...NumElements(shape)) to a + // value. + // + // TODO(vrv): Replace with something like a BrainClient Feed. + template + void AddInput(const TensorShape& shape, std::function input_mapping) { + CHECK_GT(input_types_.size(), inputs_.size()) + << "Adding more inputs than types; perhaps you need to call MakeOp"; + bool is_ref = IsRefType(input_types_[inputs_.size()]); + Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), + DataTypeToEnum::v(), shape); + test::FillFn(input, input_mapping); + tensors_.push_back(input); + if (is_ref) { + CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), + DataTypeToEnum::v()); + inputs_.push_back({&lock_for_refs_, input}); + } else { + CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum::v()); + inputs_.push_back({nullptr, input}); + } + } + + // Like AddInput but takes in an explicit arrayslice of data. 
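(AddInputFromArray itself follows below.) Taken together, these helpers give every kernel test in this patch the same shape: build a NodeDef, InitOp, feed inputs, RunOpKernel, compare outputs. Here is a condensed, hypothetical example of that flow modeled on lrn_op_test.cc; the op name "Relu", the test name, and the values are only for illustration, while the helpers are the ones declared in this header.

#include <gtest/gtest.h>

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/public/tensor.h"

namespace tensorflow {

class ExampleKernelTest : public OpsTestBase {
 protected:
  ExampleKernelTest() { RequireDefaultOps(); }
};

TEST_F(ExampleKernelTest, ReluBasic) {
  // Describe the node to run and create its kernel.
  ASSERT_OK(NodeDefBuilder("example_op", "Relu")
                .Input(FakeInput(DT_FLOAT))
                .Finalize(node_def()));
  ASSERT_OK(InitOp());
  // Feed one input tensor, run the kernel, and check the output.
  AddInputFromArray<float>(TensorShape({4}), {-1.0f, 0.0f, 2.0f, -3.0f});
  ASSERT_OK(RunOpKernel());
  Tensor expected(allocator(), DT_FLOAT, TensorShape({4}));
  test::FillValues<float>(&expected, {0.0f, 0.0f, 2.0f, 0.0f});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

}  // namespace tensorflow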
+ template + void AddInputFromArray(const TensorShape& shape, + const gtl::ArraySlice& data) { + CHECK_GT(input_types_.size(), inputs_.size()) + << "Adding more inputs than types; perhaps you need to call MakeOp"; + bool is_ref = IsRefType(input_types_[inputs_.size()]); + Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), + DataTypeToEnum::v(), shape); + test::FillValues(input, data); + tensors_.push_back(input); + if (is_ref) { + CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), + DataTypeToEnum::v()); + inputs_.push_back({&lock_for_refs_, input}); + } else { + CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum::v()); + inputs_.push_back({nullptr, input}); + } + } + + // Runs an operation producing 'num_outputs' outputs. + // + // Returns the context's status after running the operation. + Status RunOpKernel() { + OpKernelContext::Params params; + params.device = device_.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs_; + params.op_kernel = kernel_.get(); + params.output_alloc_attr = [this, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = + (kernel_->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; + params.slice_reader_cache = &slice_reader_cache_wrapper; + + context_.reset(new OpKernelContext(params)); + device_->Compute(kernel_.get(), context_.get()); + return context_->status(); + } + + // Returns the tensor input for 'input_index'. + // + // REQUIRES: 0 <= input_index < context_->num_inputs() + const Tensor& GetInput(int input_index) const { + CHECK_LT(input_index, context_->num_inputs()); + CHECK(!IsRefType(context_->input_dtype(input_index))); + return context_->input(input_index); + } + + TensorValue mutable_input(int input_index) { + CHECK_LT(input_index, inputs_.size()); + return inputs_[input_index]; + } + // Returns the tensor output for 'output_index'. + // + // REQUIRES: 0 <= output_index < context_->num_outputs() + Tensor* GetOutput(int output_index) { + CHECK_LT(output_index, context_->num_outputs()); + return context_->mutable_output(output_index); + } + + Allocator* allocator() { + return device_->GetAllocator(AllocatorAttributes()); + } + + const DataTypeVector& output_types() const { return kernel_->output_types(); } + + protected: + std::unique_ptr device_; + + std::unique_ptr kernel_; + NodeDef node_def_; + DataTypeVector input_types_; + DeviceType device_type_; + + mutex lock_for_refs_; // Used as the Mutex for inputs added as refs + + gtl::InlinedVector inputs_; + // Owns Tensors. + std::vector tensors_; + + std::unique_ptr context_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc new file mode 100644 index 0000000000..ca2925128e --- /dev/null +++ b/tensorflow/core/kernels/ops_util.cc @@ -0,0 +1,113 @@ +#include + +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { + +void RequireDefaultOps() { +// TODO(opensource): Use a more generic sounding preprocessor name than +// GOOGLE_CUDA (maybe SUPPORT_CUDA?) 
+#if GOOGLE_CUDA + void RequireGPUDevice(); + RequireGPUDevice(); +#endif +} + +Status Get2dOutputSize(const int in_height, const int in_width, + int filter_height, int filter_width, int row_stride, + int col_stride, Padding padding, int* new_height, + int* new_width, int* pad_rows, int* pad_cols) { + int pad_bottom_unused, pad_right_unused; + return Get2dOutputSizeVerbose( + in_height, in_width, filter_height, filter_width, row_stride, col_stride, + padding, new_height, new_width, pad_rows, &pad_bottom_unused, pad_cols, + &pad_right_unused); +} + +Status Get2dOutputSizeVerbose(const int in_height, const int in_width, + int filter_height, int filter_width, + int row_stride, int col_stride, Padding padding, + int* new_height, int* new_width, int* pad_top, + int* pad_bottom, int* pad_left, int* pad_right) { + // Cannot have strides larger than the patch size. + if (row_stride > filter_height || col_stride > filter_width) { + return errors::InvalidArgument( + "stride must be less than or equal to kernel size"); + } + switch (padding) { + case Padding::VALID: + *new_height = ceil((in_height - filter_height + 1.f) / + static_cast(row_stride)); + *new_width = ceil((in_width - filter_width + 1.f) / + static_cast(col_stride)); + *pad_top = 0; + *pad_bottom = 0; + *pad_left = 0; + *pad_right = 0; + break; + case Padding::SAME: + *new_height = ceil(in_height / static_cast(row_stride)); + *new_width = ceil(in_width / static_cast(col_stride)); + // Calculate padding for top/bottom/left/right, spilling any excess + // padding to bottom and right. + const int pad_needed_height = + (*new_height - 1) * row_stride + filter_height - in_height; + *pad_top = pad_needed_height / 2; + CHECK_GE(pad_needed_height, 0); + *pad_bottom = pad_needed_height - *pad_top; + + const int pad_needed_width = + (*new_width - 1) * col_stride + filter_width - in_width; + *pad_left = pad_needed_width / 2; + CHECK_GE(pad_needed_width, 0); + *pad_right = pad_needed_width - *pad_left; + break; + } + if (*new_height < 0 || *new_width < 0) { + return errors::InvalidArgument("computed output size would be negative"); + } + return Status::OK(); +} + +Eigen::PaddingType BrainPadding2EigenPadding(Padding padding) { + switch (padding) { + case Padding::VALID: + return Eigen::PADDING_VALID; + case Padding::SAME: + return Eigen::PADDING_SAME; + } + return Eigen::PADDING_SAME; // Prevent compiler warning about missing return +} + +Status GetBroadcastSize(const int index, const int in_size, + const int ksize, const int stride, + const int pad_size, int* bindex, int* bsize) { + // Cannot have strides larger than the patch size. + if (stride > ksize) { + return errors::InvalidArgument( + "stride must be less than or equal to kernel size"); + } + // Cannot have index beyond the input size. + if (index * stride > in_size) { + return errors::InvalidArgument( + "index * stride must be less than or equal to input size"); + } + *bindex = index * stride; + *bsize = ksize; + if (*bindex < pad_size) { + // If the current index is in the padding area, start broadcast from index + // 0 with broadcast size reduced by padding size. + *bsize = ksize + *bindex - pad_size; + *bindex = 0; + } else { + // Otherwise, start broadcast from current index reduced by padding size. 
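// (Illustrative numbers, matching the unit tests in ops_util_test.cc:) with
// in_size = 3, ksize = 3, stride = 1, pad_size = 1 and index = 2, *bindex
// becomes 2 * 1 - 1 = 1, and the clipping step below then shrinks *bsize to
// in_size - *bindex = 2, so the result is (bindex, bsize) = (1, 2).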
+ *bindex -= pad_size; + } + if (*bindex + ksize > in_size) { + *bsize = std::min((in_size - *bindex), ksize); + } + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h new file mode 100644 index 0000000000..283338f8df --- /dev/null +++ b/tensorflow/core/kernels/ops_util.h @@ -0,0 +1,180 @@ +#ifndef TENSORFLOW_KERNELS_OPS_UTIL_H_ +#define TENSORFLOW_KERNELS_OPS_UTIL_H_ + +// This file contains utilities for various operations. + +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +// Call this function from a test if op kernels are not being +// registered. This can happen if the test is linked in a shared +// mode and has no direct references to any code from this directory. +void RequireDefaultOps(); + +// Get2dOutputSize(): Given an input tensor, kernel, stride and padding +// type, the function computes the output and padding dimensions. +// +// Convolution layers take in an input tensor of shape (D, C, R, B), and +// convolve it with a set of filters, which can also be presented as a +// tensor (D, K, K, M), where M is the number of filters, K is the filter size, +// and each 3-dimensional tensor of size (D, K, K) is a filter. For +// simplicity we assume that we always use square filters (which is usually the +// case in images). It also takes in a few additional parameters: +// +// Stride (S): the stride with which we apply the filters. This is the offset +// between locations where we apply the filters. A larger stride +// means that the output will be spatially smaller. +// +// Padding (P): the padding we apply to the input tensor along the R and C +// dimensions. This is usually used to make sure that the spatial dimension +// do not shrink when we progress with convolutions. Two types of padding are +// often used: +// SAME: the pad value is computed so that the output will have size R/S +// and C/S. +// VALID: no padding is carried out. +// The padded area is zero-filled. +// +// The output dimensions for convolution and many other operations, when given +// all the parameters above, are as follows: +// - When Padding = SAME: the output size is (B, R', C', M), where +// R' = ceil(float(R) / float(S)) +// C' = ceil(float(C) / float(S)) +// where ceil is the ceiling function. The number of padded rows and columns +// are computed as: +// Pr = ((R' - 1) * S + K - R) / 2 +// Pc = ((C' - 1) * S + K - C) / 2 +// When the stride is 1, we have the simplified case +// R'=R, C'=C, Pr=Pc=(K-1)/2. +// This is where SAME comes from - the output has the same size as the input +// has. +// +// - When Padding = VALID: the output size is computed as +// R' = ceil(float(R - K + 1) / float(S)) +// C' = ceil(float(C - K + 1) / float(S)) +// and the number of padded rows and columns are computed in the same way. +// When the stride is 1, we have the simplified case +// R'=R-K+1, C'=C-K+1, Pr=0, Pc=0. +// +// For convolution, mathematically, the output value at location (b, r', c', m) +// is the inner product of two vectors: the chunk of input at +// (b, (r'*S-Pr) : (r'*S-Pr+K), (c'*S-Pc) : (c'*S-Pc+K), :), +// and the filter at (m, :, :, :). 
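// A worked example of the formulas above (illustrative numbers): for
// R = C = 5, K = 3, S = 2,
//   SAME:  R' = C' = ceil(5 / 2) = 3, total padding per dimension is
//          (3 - 1) * 2 + 3 - 5 = 2, reported by Get2dOutputSize as
//          Pr = Pc = 1 (any excess padding from an odd total goes to the
//          bottom/right in the Verbose variant).
//   VALID: R' = C' = ceil((5 - 3 + 1) / 2) = 2 with Pr = Pc = 0.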
+// +Status Get2dOutputSize(const int in_height, const int in_width, + int filter_height, int filter_width, int row_stride, + int col_stride, Padding padding, int* new_height, + int* new_width, int* pad_rows, int* pad_cols); + +// Returns the same output dimensions as in Get2dOutputSize, but returns verbose +// padding dimensions (top/bottom/left/right). Any excess padding (caused by +// an odd padding size value) is added to the 'pad_bottom' and 'pad_right' +// dimensions. +Status Get2dOutputSizeVerbose(const int in_height, const int in_width, + int filter_height, int filter_width, + int row_stride, int col_stride, Padding padding, + int* new_height, int* new_width, int* pad_top, + int* pad_bottom, int* pad_left, int* pad_right); + +// Calculates broadcast starting index and size. For SAME padding, addition +// padding could be applied to right, left, top and bottom. Depending on the +// current index, input size, kernel size, stride, padding size, the starting +// index and size for broadcast for that dimension are different from the +// current index and kernel size. +// This is mainly used by gradient algorithms for pooling operations. +Status GetBroadcastSize(const int index, const int in_size, + const int ksize, const int stride, + const int pad_size, int* bindex, int* bsize); + +// Converts Brain's Padding to Eigen's PaddingType. +Eigen::PaddingType BrainPadding2EigenPadding(Padding padding); + +// Given a shape 's' of a tensor of type T. Returns true iff the +// number of bytes occupied by each dim 0 (i.e., &tensor(i + 1, ...) - +// &tensor(i, ...)) is multiple of EIGEN_ALIGN_BYTES. +template +bool IsInnerDimsSizeAligned(const TensorShape& s) { + if (s.dims() == 0) return false; + const int64 dim0_size = s.dim_size(0); + if (dim0_size == 0) return false; + const int64 bytes_per_dim0 = (s.num_elements() / dim0_size) * sizeof(T); + return bytes_per_dim0 % EIGEN_MAX_ALIGN_BYTES == 0; +} + +// Returns in 'col_data', image patches in storage order (height, width, depth) +// extracted from image at 'input_data', which is requred to be in storage +// order (batch, height, width, depth). +// Implementation written by Yangqing Jia (jiayq). +template +void Im2col(const T* input_data, const int depth, const int height, + const int width, const int filter_h, const int filter_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, const int stride_w, T* col_data) { + int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ih >= 0 && ih < height && iw >= 0 && iw < width) { + memcpy(col_data, input_data + (ih * width + iw) * depth, + sizeof(T) * depth); + } else { + // This should be simply padded with zero. + memset(col_data, 0, sizeof(T) * depth); + } + col_data += depth; + } + } + w_pad += stride_w; + } + h_pad += stride_h; + } +} + +// Returns in 'im_data' image patch in storage order (height, width, depth), +// constructed from patches in 'col_data', which is required to be in storage +// order (out_height * out_width, filter_height, filter_width, in_depth). +// Implementation by Yangqing Jia (jiayq). 
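// Col2im is effectively the adjoint of Im2col above: where patches overlap,
// contributions are accumulated (+=) rather than copied, which is why
// 'im_data' is zeroed first. For example (illustrative numbers), a 4x4x1
// image with a 3x3 filter, stride 1 and no padding yields 2x2 = 4 patches in
// Im2col; running Col2im on those patches adds each pixel back once per
// patch that contains it (up to 4 times for the interior pixels here).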
+template +void Col2im(const T* col_data, const int depth, const int height, + const int width, const int filter_h, const int filter_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, const int stride_w, T* im_data) { + memset(im_data, 0, sizeof(T) * height * width * depth); + int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + T* im_patch_data = im_data + (h_pad * width + w_pad) * depth; + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ih >= 0 && ih < height && iw >= 0 && iw < width) { + // TODO(andydavis) Vectorize this loop (if compiler does not). + for (int i = 0; i < depth; ++i) { + im_patch_data[i] += col_data[i]; + } + } + im_patch_data += depth; + col_data += depth; + } + // Jump over remaining number of depth. + im_patch_data += depth * (width - filter_w); + } + w_pad += stride_w; + } + h_pad += stride_h; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_ diff --git a/tensorflow/core/kernels/ops_util_test.cc b/tensorflow/core/kernels/ops_util_test.cc new file mode 100644 index 0000000000..bc4f57e220 --- /dev/null +++ b/tensorflow/core/kernels/ops_util_test.cc @@ -0,0 +1,265 @@ +#include "tensorflow/core/kernels/ops_util.h" +#include + +namespace tensorflow { +namespace { + +class OpsUtilTest : public ::testing::Test { + protected: + OpsUtilTest() {} + ~OpsUtilTest() override {} + + // Padding structure. + struct padding_struct { + // Input parameters. + struct { + int in_height; + int in_width; + int filter_height; + int filter_width; + int row_stride; + int col_stride; + Padding padding; + } input; + // Output. + struct { + int new_height; + int new_width; + int pad_top; + int pad_bottom; + int pad_left; + int pad_right; + } output; + }; + + // Broadcast structure. + struct bcast_struct { + // Input parameters. + struct { + int index; // Current index. + int in_size; // Size of the dimension. + int ksize; // Kernel size. + int stride; // Stride. + int pad_size; // Padding size. + } input; + // Output. + struct { + int new_index; // New starting index. + int new_size; // New broadcast size. 
+ } output; + }; + + static void VerifyGet2dOutputSizeBoundaries(padding_struct pad_struct, + error::Code code) { + int new_height, new_width, pad_rows, pad_cols; + Status status = Get2dOutputSize( + pad_struct.input.in_height, pad_struct.input.in_width, + pad_struct.input.filter_height, pad_struct.input.filter_width, + pad_struct.input.row_stride, pad_struct.input.col_stride, + pad_struct.input.padding, &new_height, &new_width, &pad_rows, + &pad_cols); + EXPECT_EQ(status.code(), code) << status; + } + + static void VerifyGet2dOutputSizeValues(padding_struct pad_struct, + error::Code code) { + int new_height, new_width, pad_rows, pad_cols; + Status status = Get2dOutputSize( + pad_struct.input.in_height, pad_struct.input.in_width, + pad_struct.input.filter_height, pad_struct.input.filter_width, + pad_struct.input.row_stride, pad_struct.input.col_stride, + pad_struct.input.padding, &new_height, &new_width, &pad_rows, + &pad_cols); + EXPECT_EQ(status.code(), code) << status; + EXPECT_EQ(pad_struct.output.new_height, new_height); + EXPECT_EQ(pad_struct.output.new_width, new_width); + EXPECT_EQ(pad_struct.output.pad_top, pad_rows); + EXPECT_EQ(pad_struct.output.pad_left, pad_cols); + } + + static void VerifyGet2dOutputVerboseSizeValues(padding_struct pad_struct, + error::Code code) { + int new_height, new_width, pad_top, pad_bottom, pad_left, pad_right; + Status status = Get2dOutputSizeVerbose( + pad_struct.input.in_height, pad_struct.input.in_width, + pad_struct.input.filter_height, pad_struct.input.filter_width, + pad_struct.input.row_stride, pad_struct.input.col_stride, + pad_struct.input.padding, &new_height, &new_width, &pad_top, + &pad_bottom, &pad_left, &pad_right); + EXPECT_EQ(status.code(), code) << status; + EXPECT_EQ(pad_struct.output.new_height, new_height); + EXPECT_EQ(pad_struct.output.new_width, new_width); + EXPECT_EQ(pad_struct.output.pad_top, pad_top); + EXPECT_EQ(pad_struct.output.pad_bottom, pad_bottom); + EXPECT_EQ(pad_struct.output.pad_left, pad_left); + EXPECT_EQ(pad_struct.output.pad_right, pad_right); + } + + static void VerifyBoundaries(bcast_struct bcast, error::Code code) { + int new_index, new_size; + Status status = GetBroadcastSize( + bcast.input.index, bcast.input.in_size, bcast.input.ksize, + bcast.input.stride, bcast.input.pad_size, &new_index, &new_size); + EXPECT_EQ(status.code(), code) << status; + } + + static void VerifyBcastValues(bcast_struct bcast) { + int new_index, new_size; + EXPECT_EQ(Status::OK(), + GetBroadcastSize(bcast.input.index, bcast.input.in_size, + bcast.input.ksize, bcast.input.stride, + bcast.input.pad_size, &new_index, &new_size)); + EXPECT_EQ(bcast.output.new_index, new_index); + EXPECT_EQ(bcast.output.new_size, new_size); + } +}; + +// Test stride > ksize fails with INVALID_ARGUMENT. 
+TEST_F(OpsUtilTest, Get2dOutputSizeInvalidTest) { + padding_struct pad_struct = {{3, 3, 1, 2, 2, 2, SAME}, {3, 3, 1, 1, 1, 1}}; + VerifyGet2dOutputSizeBoundaries(pad_struct, error::INVALID_ARGUMENT); +} + +TEST_F(OpsUtilTest, Get2dOutputSizeNegativeSizeTest) { + padding_struct pad_struct = {{1, 1, 3, 3, 1, 1, VALID}, {-1, -1, 0, 0, 0, 0}}; + VerifyGet2dOutputSizeBoundaries(pad_struct, error::INVALID_ARGUMENT); +} + +TEST_F(OpsUtilTest, Get2dOutputSizeSquareFilterTest) { + padding_struct pad_struct1 = {{3, 3, 2, 2, 2, 2, SAME}, {2, 2, 0, 0, 0, 0}}; + padding_struct pad_struct2 = {{3, 3, 2, 2, 2, 2, VALID}, {1, 1, 0, 0, 0, 0}}; + VerifyGet2dOutputSizeValues(pad_struct1, error::OK); + VerifyGet2dOutputSizeValues(pad_struct2, error::OK); +} + +TEST_F(OpsUtilTest, Get2dOutputSizeNonSquareFilterTest) { + padding_struct pad_struct1 = {{4, 5, 1, 2, 1, 1, SAME}, {4, 5, 0, 0, 0, 0}}; + padding_struct pad_struct2 = {{4, 5, 1, 2, 1, 1, VALID}, {4, 4, 0, 0, 0, 0}}; + VerifyGet2dOutputSizeValues(pad_struct1, error::OK); + VerifyGet2dOutputSizeValues(pad_struct2, error::OK); +} + +TEST_F(OpsUtilTest, Get2dOutputSizeUnevenStrideTest) { + padding_struct pad_struct1 = {{4, 4, 2, 2, 1, 2, VALID}, {3, 2, 0, 0, 0, 0}}; + padding_struct pad_struct2 = {{4, 4, 2, 2, 2, 1, VALID}, {2, 3, 0, 0, 0, 0}}; + VerifyGet2dOutputSizeValues(pad_struct1, error::OK); + VerifyGet2dOutputSizeValues(pad_struct2, error::OK); +} + +TEST_F(OpsUtilTest, Get2dOutputSizeVerbose) { + padding_struct pad_struct1 = {{3, 3, 2, 2, 2, 2, SAME}, {2, 2, 0, 1, 0, 1}}; + padding_struct pad_struct2 = {{3, 3, 2, 2, 2, 2, VALID}, {1, 1, 0, 0, 0, 0}}; + VerifyGet2dOutputVerboseSizeValues(pad_struct1, error::OK); + VerifyGet2dOutputVerboseSizeValues(pad_struct2, error::OK); +} + +// Test stride > ksize fails with INVALID_ARGUMENT. +TEST_F(OpsUtilTest, GetBroadcastTest3_1_2_0) { + bcast_struct bcast = {{0, 3, 1, 2, 0}, {0, 3}}; + VerifyBoundaries(bcast, error::INVALID_ARGUMENT); +} + +// Test index * stride > in_size fails with INVALID_ARGUMENT. 
+TEST_F(OpsUtilTest, GetBroadcastTestBadIndex) { + bcast_struct bcast = {{2, 3, 1, 2, 0}, {0, 3}}; + VerifyBoundaries(bcast, error::INVALID_ARGUMENT); +} + +// in_size = 3, ksize = 3, stride = 1, pad_size = 0 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_0) { + bcast_struct bcast[] = { + {{0, 3, 3, 1, 0}, {0, 3}}, + {{1, 3, 3, 1, 0}, {1, 2}}, + {{2, 3, 3, 1, 0}, {2, 1}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 3, stride = 1, pad_size = 1 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_1) { + bcast_struct bcast[] = { + {{0, 3, 3, 1, 1}, {0, 2}}, + {{1, 3, 3, 1, 1}, {0, 3}}, + {{2, 3, 3, 1, 1}, {1, 2}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 3, stride = 1, pad_size = 2 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_2) { + bcast_struct bcast[] = { + {{0, 3, 3, 1, 2}, {0, 1}}, + {{1, 3, 3, 1, 2}, {0, 2}}, + {{2, 3, 3, 1, 2}, {0, 3}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 3, stride = 2, pad_size = 0 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_0) { + bcast_struct bcast[] = { + {{0, 3, 3, 2, 0}, {0, 3}}, {{1, 3, 3, 2, 0}, {2, 1}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 3, stride = 2, pad_size = 1 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_1) { + bcast_struct bcast[] = { + {{0, 3, 3, 2, 1}, {0, 2}}, {{1, 3, 3, 2, 1}, {1, 2}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 3, stride = 2, pad_size = 2 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_2) { + bcast_struct bcast[] = { + {{0, 3, 3, 2, 2}, {0, 1}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 3, stride = 3, pad_size = 0 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_0) { + bcast_struct bcast[] = { + {{0, 3, 3, 3, 0}, {0, 3}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 3, stride = 3, pad_size = 1 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_1) { + bcast_struct bcast[] = { + {{0, 3, 3, 3, 1}, {0, 2}}, {{1, 3, 3, 3, 1}, {2, 1}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 3, stride = 3, pad_size = 2 +TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_2) { + bcast_struct bcast[] = { + {{0, 3, 3, 3, 2}, {0, 1}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc new file mode 100644 index 0000000000..cb125ea2fe --- /dev/null +++ b/tensorflow/core/kernels/pack_op.cc @@ -0,0 +1,114 @@ +// See docs in ../ops/array_ops.cc. 
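// Shape-level summary of the kernel below (illustrative numbers): packing
// three float tensors of shape [2, 2] produces one tensor of shape
// [3, 2, 2]. Internally each input is viewed as a 1 x NumElements() matrix
// (here 1 x 4) and those matrices are concatenated into the 1 x 12
// flattened output, which is why Pack can reuse the Concat kernels.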
+ +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/concat_op.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// -------------------------------------------------------------------------- +template +class PackOp : public OpKernel { + public: + typedef std::vector::ConstMatrix>> + ConstMatrixVector; + + explicit PackOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + OpInputList values; + OP_REQUIRES_OK(c, c->input_list("values", &values)); + const int num = values.size(); + + // Verify that all input shapes match + for (int i = 1; i < num; i++) { + OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()), + errors::InvalidArgument( + "Shapes of all inputs must match: values[0].shape = ", + values[0].shape().ShortDebugString(), " != values[", i, + "].shape = ", values[i].shape().ShortDebugString())); + } + + TensorShape output_shape(values[0].shape()); + output_shape.InsertDim(0, num); + + // In the num = 1 case, just reshape the input + if (num == 1) { + Tensor output; + CHECK(output.CopyFrom(values[0], output_shape)); + c->set_output(0, output); + return; + } + + // Allocate output + Tensor* output; + OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); + + const int output_size = output->NumElements(); + if (output_size > 0) { + auto output_flat = output->shaped({1, output_size}); + + // Except for shapes, pack is a special case of concat, so we reuse the + // same computational kernels. + ConstMatrixVector inputs_flat; + inputs_flat.reserve(num); + for (int i = 0; i < num; ++i) { + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + values[i].shaped({1, values[i].NumElements()}))); + } + if (std::is_same::value) { + ConcatGPU(c->eigen_gpu_device(), inputs_flat, &output_flat); + } else { + ConcatCPU(c->device(), inputs_flat, &output_flat); + } + } + } +}; + +#define REGISTER_PACK(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Pack").Device(DEVICE_CPU).TypeConstraint("T"), \ + PackOp) + +TF_CALL_ALL_TYPES(REGISTER_PACK); +REGISTER_PACK(quint8); +REGISTER_PACK(qint8); +REGISTER_PACK(qint32); +REGISTER_PACK(bfloat16); + +#undef REGISTER_PACK + +#if GOOGLE_CUDA + +#define REGISTER_GPU(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Pack").Device(DEVICE_GPU).TypeConstraint("T"), \ + PackOp) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +#undef REGISTER_GPU + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Pack") + .Device(DEVICE_GPU) + .HostMemory("values") + .HostMemory("output") + .TypeConstraint("T"), + PackOp); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc new file mode 100644 index 0000000000..6c66e54e3d --- /dev/null +++ b/tensorflow/core/kernels/pad_op.cc @@ -0,0 +1,159 @@ +// See docs in ../ops/nn_ops.cc. 
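// Shape rule implemented by the kernel below (illustrative numbers): the
// second input is a [rank, 2] matrix of (pad_before, pad_after) pairs, and
// output dimension d has size paddings(d, 0) + input.dim_size(d) +
// paddings(d, 1). E.g. an input of shape [2, 3] with paddings
// [[1, 1], [2, 0]] produces a zero-padded output of shape [4, 5].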
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/pad_op.h" + +#include +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class PadOp : public OpKernel { + public: + explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& in0 = context->input(0); + const Tensor& in1 = context->input(1); + const int dims = in0.dims(); + static const int kMinDims = 0; + static const int kMaxDims = 5; + OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims, + errors::Unimplemented("inputs rank not in [", kMinDims, ",", + kMaxDims, "]: ", dims)); + OP_REQUIRES( + context, + TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2, + errors::InvalidArgument("paddings must be a matrix with 2 columns: ", + in1.shape().DebugString())); + const int fixed_dims = + (kAllowLegacyScalars && dims == 0 && in1.dim_size(0) == 1) ? 1 : dims; + OP_REQUIRES( + context, fixed_dims == in1.dim_size(0), + errors::InvalidArgument( + "The first dimension of paddings must be the rank of inputs", + in1.shape().DebugString(), " ", in0.shape().DebugString())); + + // Compute the shape of the output tensor, and allocate it. + TensorShape output_shape; + TTypes::ConstMatrix paddings = in1.matrix(); + for (int d = 0; d < fixed_dims; ++d) { + const int32 before_d = paddings(d, 0); // Pad before existing elements. + const int32 after_d = paddings(d, 1); // Pad after exisitng elements. + OP_REQUIRES(context, before_d >= 0 && after_d >= 0, + errors::InvalidArgument("Paddings must be non-negative: ", + before_d, " ", after_d)); + const int size_d = + (kAllowLegacyScalars && d == in0.dims()) ? 1 : in0.dim_size(d); + output_shape.AddDim(before_d + size_d + after_d); + } + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + // Invoke the dims-specific implementation. + switch (fixed_dims) { + case 0: + Operate<0>(context, in0.tensor(), paddings, output); + break; + case 1: + // TODO(irving): Once Pad doesn't need a scalar special case, + // change flat to tensor. That is, once !kAllowLegacyScalars. 
+ Operate<1>(context, in0.flat(), paddings, output); + break; + case 2: + Operate<2>(context, in0.tensor(), paddings, output); + break; + case 3: + Operate<3>(context, in0.tensor(), paddings, output); + break; + case 4: + Operate<4>(context, in0.tensor(), paddings, output); + break; + case 5: + Operate<5>(context, in0.tensor(), paddings, output); + break; + default: + OP_REQUIRES(context, false, + errors::InvalidArgument("Only ranks up to 5 supported: ", + in0.shape().DebugString())); + } + } + + private: + template + void Operate(OpKernelContext* context, + typename TTypes::ConstTensor input, + TTypes::ConstMatrix paddings, Tensor* output) { + CHECK_EQ(Dims, paddings.dimension(0)); + CHECK_EQ(2, paddings.dimension(1)); + Eigen::array, Dims> paddings_array; + for (int i = 0; i < Dims; ++i) { + paddings_array[i] = std::make_pair(paddings(i, 0), paddings(i, 1)); + } + functor::Pad functor; + functor(context->eigen_device(), output->tensor(), input, + paddings_array); + } +}; + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Pad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("paddings"), \ + PadOp) + +TF_CALL_ALL_TYPES(REGISTER_KERNEL); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T, Dims) \ + template <> \ + void Pad::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + Eigen::array, Dims> paddings); \ + extern template struct Pad; + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC(T, 0); \ + DECLARE_GPU_SPEC(T, 1); \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("Pad") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("paddings"), \ + PadOp) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); +#endif // GOOGLE_CUDA + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/pad_op.h b/tensorflow/core/kernels/pad_op.h new file mode 100644 index 0000000000..c4f8a4abda --- /dev/null +++ b/tensorflow/core/kernels/pad_op.h @@ -0,0 +1,27 @@ +#ifndef TENSORFLOW_KERNELS_PAD_OP_H_ +#define TENSORFLOW_KERNELS_PAD_OP_H_ +// Functor definition for PadOp, must be compilable by nvcc. + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by PadOp to do the computations. +template +struct Pad { + // Pad "input" into "output", as specified by "paddings". See pad_op.cc for + // details. 
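// A minimal standalone sketch of the Eigen expression used here
// (hypothetical values; assumes only the Eigen Tensor module and <utility>):
//
//   Eigen::Tensor<float, 2> in(2, 3);
//   in.setConstant(1.0f);
//   Eigen::array<std::pair<int, int>, 2> pads;
//   pads[0] = std::make_pair(1, 1);  // dim 0: 1 before, 1 after
//   pads[1] = std::make_pair(2, 0);  // dim 1: 2 before, 0 after
//   Eigen::Tensor<float, 2> out = in.pad(pads);  // 4 x 5, zeros outside 'in'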
+ void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + Eigen::array, Dims> paddings) { + output.device(d) = input.pad(paddings); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_PAD_OP_H_ diff --git a/tensorflow/core/kernels/pad_op_gpu.cu.cc b/tensorflow/core/kernels/pad_op_gpu.cu.cc new file mode 100644 index 0000000000..35a03a2cb2 --- /dev/null +++ b/tensorflow/core/kernels/pad_op_gpu.cu.cc @@ -0,0 +1,26 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/pad_op.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Definition of the GPU implementations declared in pad_op.cc. +#define DEFINE_GPU_SPECS(T) \ + template struct functor::Pad; \ + template struct functor::Pad; \ + template struct functor::Pad; \ + template struct functor::Pad; \ + template struct functor::Pad; \ + template struct functor::Pad; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc new file mode 100644 index 0000000000..35e9bd75fa --- /dev/null +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -0,0 +1,252 @@ +#include "tensorflow/core/kernels/pooling_ops_common.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/public/tensor.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/maxpooling_op_gpu.h" +#include "tensorflow/core/kernels/pooling_ops_common_gpu.h" +#include "tensorflow/stream_executor/dnn.h" +#include "tensorflow/stream_executor/stream.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +PoolParameters::PoolParameters(OpKernelContext* context, + const std::vector& ksize, + const std::vector& stride, + Padding padding, + const TensorShape& tensor_in_shape) { + // For maxpooling, tensor_in should have 4 dimensions. + OP_REQUIRES(context, tensor_in_shape.dims() == 4, + errors::InvalidArgument("tensor_in must be 4-dimensional")); + + depth = tensor_in_shape.dim_size(3); + tensor_in_cols = tensor_in_shape.dim_size(2); + tensor_in_rows = tensor_in_shape.dim_size(1); + tensor_in_batch = tensor_in_shape.dim_size(0); + window_rows = ksize[1]; + window_cols = ksize[2]; + depth_window = ksize[3]; + row_stride = stride[1]; + col_stride = stride[2]; + depth_stride = stride[3]; + + // We only support 2D pooling across width/height and depthwise + // pooling, not a combination. + OP_REQUIRES(context, + (depth_window == 1 || (window_rows == 1 && window_cols == 1)), + errors::Unimplemented( + "MaxPooling supports exactly one of pooling across depth " + "or pooling across width/height.")); + + if (depth_window == 1) { + OP_REQUIRES_OK(context, Get2dOutputSize( + tensor_in_rows, tensor_in_cols, window_rows, + window_cols, row_stride, col_stride, padding, + &out_height, &out_width, &pad_rows, &pad_cols)); + } else { + // Our current version of depthwise max pooling does not support + // any padding, and expects the depth_window to equal the + // depth_stride (no overlapping). 
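// (Illustrative numbers:) with an NHWC input of depth 8 and
// ksize = {1, 1, 1, 2}, strides = {1, 1, 1, 2}, both depth_window and
// depth_stride are 2, the checks below pass, and out_depth = 8 / 2 = 4
// while the spatial dimensions are left unchanged.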
+ OP_REQUIRES( + context, depth % depth_window == 0, + errors::Unimplemented("Depthwise max pooling requires the depth " + "window to evenly divide the input depth")); + OP_REQUIRES( + context, depth_stride == depth_window, + errors::Unimplemented("Depthwise max pooling requires the depth " + "window to equal the depth stride")); + + // The current version of depthwise max is only implemented on CPU. + OP_REQUIRES(context, + (DeviceType(static_cast(context->device()) + ->attributes() + .device_type()) == DeviceType(DEVICE_CPU)), + errors::Unimplemented("Depthwise max pooling is currently " + "only implemented for CPU devices.")); + + pad_depth = 0; + out_depth = depth / depth_window; + } +} + +TensorShape PoolParameters::forward_output_shape() { + if (depth_window == 1) { + // Spatial pooling + return TensorShape({tensor_in_batch, out_height, out_width, depth}); + } else { + // Depthwise pooling + return TensorShape( + {tensor_in_batch, tensor_in_rows, tensor_in_cols, out_depth}); + } +} + +#ifdef GOOGLE_CUDA + +namespace { +template +perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory, + uint64 size) { + perftools::gputools::DeviceMemoryBase wrapped(const_cast(cuda_memory), + size * sizeof(T)); + perftools::gputools::DeviceMemory typed(wrapped); + return typed; +} +} // namespace + +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void TransformDepth::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const Eigen::DSizes& shuffle, \ + typename TTypes::Tensor out); \ + extern template struct TransformDepth; + +DECLARE_GPU_SPEC(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +template +void DnnPoolingGradOp::Compute( + OpKernelContext* context, + perftools::gputools::dnn::PoolingMode pooling_mode, + const std::vector& size, const std::vector& stride, + Padding padding, const Tensor* tensor_in, const Tensor* tensor_out, + const Tensor& out_backprop, const TensorShape& tensor_in_shape) { + CHECK((pooling_mode == perftools::gputools::dnn::PoolingMode::kMaximum) || + (tensor_in && tensor_out)) + << "For MaxPoolGrad, both tensor_in and tensor_out needs to be " + "specified"; + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, tensor_in_shape, &output)); + + PoolParameters params{context, size, stride, padding, tensor_in_shape}; + if (!context->status().ok()) { + return; + } + + /// For now, cudnn does not support NHWC format, so we need to convert it + /// to NCHW before calling cudnn. 
We need to get rid of this once it is done + Tensor transformed_input; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({tensor_in_shape.dim_size(0), + tensor_in_shape.dim_size(3), + tensor_in_shape.dim_size(1), + tensor_in_shape.dim_size(2)}), + &transformed_input)); + Tensor transformed_input_backprop; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({tensor_in_shape.dim_size(0), + tensor_in_shape.dim_size(3), + tensor_in_shape.dim_size(1), + tensor_in_shape.dim_size(2)}), + &transformed_input_backprop)); + Tensor transformed_output; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({out_backprop.dim_size(0), out_backprop.dim_size(3), + out_backprop.dim_size(1), out_backprop.dim_size(2)}), + &transformed_output)); + Tensor transformed_output_backprop; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({out_backprop.dim_size(0), out_backprop.dim_size(3), + out_backprop.dim_size(1), out_backprop.dim_size(2)}), + &transformed_output_backprop)); + + auto nhwc_to_nchw = Eigen::DSizes(0, 3, 1, 2); + if (tensor_in) { + // For AvgPoolGrad, the original input tensor is not necessary. However, + // cudnn still requires them to run, although they do not affect the + // results. + functor::TransformDepth()( + context->eigen_device(), tensor_in->tensor(), + nhwc_to_nchw, transformed_input.tensor()); + } + if (tensor_out) { + // For AvgPoolGrad, the original output tensor is not necessary. However, + // cudnn still requires them to run, although they do not affect the + // results. + functor::TransformDepth()( + context->eigen_device(), tensor_out->tensor(), + nhwc_to_nchw, transformed_output.tensor()); + } + functor::TransformDepth()( + context->eigen_device(), out_backprop.tensor(), + nhwc_to_nchw, transformed_output_backprop.tensor()); + + /// Get ready to call cudnn + perftools::gputools::dnn::PoolingDescriptor pooling_desc; + pooling_desc.set_pooling_mode(pooling_mode) + .set_window_height(params.window_rows) + .set_window_width(params.window_cols) + .set_vertical_stride(params.row_stride) + .set_horizontal_stride(params.col_stride) + .set_vertical_padding(params.pad_rows) + .set_horizontal_padding(params.pad_cols); + + perftools::gputools::dnn::BatchDescriptor orig_output_desc; + orig_output_desc.set_count(params.tensor_in_batch) + .set_height(params.out_height) + .set_width(params.out_width) + .set_feature_map_count(params.depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + + perftools::gputools::dnn::BatchDescriptor orig_input_desc; + orig_input_desc.set_count(params.tensor_in_batch) + .set_height(params.tensor_in_rows) + .set_width(params.tensor_in_cols) + .set_feature_map_count(params.depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + + auto orig_output_data = + AsDeviceMemory(transformed_output.template flat().data(), + transformed_output.template flat().size()); + auto orig_input_data = + AsDeviceMemory(transformed_input.template flat().data(), + transformed_input.template flat().size()); + auto output_backprop = + AsDeviceMemory(transformed_output_backprop.template flat().data(), + transformed_output_backprop.template flat().size()); + auto input_backprop = + AsDeviceMemory(transformed_input_backprop.template flat().data(), + transformed_input_backprop.template flat().size()); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, 
errors::Internal("No GPU stream available.")); + + bool status = + stream->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data, + orig_output_desc, orig_output_data, + output_backprop, &input_backprop) + .ok(); + OP_REQUIRES(context, status, + errors::Internal("cudnn PoolBackward launch failed")); + + /// Transform the output data from NCHW back to NHWC + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + auto nchw_to_nhwc = Eigen::DSizes(0, 2, 3, 1); + functor::TransformDepth()( + context->eigen_device(), + toConstTensor(transformed_input_backprop).template tensor(), + nchw_to_nhwc, output->tensor()); +} + +template class DnnPoolingGradOp; + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h new file mode 100644 index 0000000000..5bf44b6e40 --- /dev/null +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -0,0 +1,264 @@ +#ifndef TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_ +#define TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_ + +#include + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/avgpooling_op.h" +#include "tensorflow/core/kernels/maxpooling_op.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// A helper class to manage sizes and shapes for pooling operations. +struct PoolParameters { + // Updates context->status if there is an invalid input. + PoolParameters(OpKernelContext* context, const std::vector& ksize, + const std::vector& stride, Padding padding, + const TensorShape& tensor_in_shape); + + // Returns the shape of the output for "forward" pooling operations. + TensorShape forward_output_shape(); + + int depth; + + int tensor_in_cols; + int tensor_in_rows; + int tensor_in_batch; + + int window_rows; + int window_cols; + int depth_window; + + int row_stride; + int col_stride; + int depth_stride; + + int out_height; + int out_width; + int out_depth; + + int pad_rows; + int pad_cols; + int pad_depth; +}; + +// An implementation of MaxPooling (forward). 
+template +class MaxPoolingOp : public UnaryOp { + public: + explicit MaxPoolingOp(OpKernelConstruction* context) : UnaryOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + PoolParameters params{context, ksize_, stride_, padding_, + tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, params.forward_output_shape(), &output)); + + if (params.depth_window > 1) { + DepthwiseMaxPool(context, output, tensor_in, params); + } else { + SpatialMaxPool(context, output, tensor_in, params, padding_); + } + } + + private: + // Single-threaded implementation of DepthwiseMaxPool which + // does not handle all of the same options as SpatialMaxPool + // (strict assumptions on no padding, stride). + // + // TODO(vrv): implement a more general depthwise-max pool that works + // on GPU as well. + void DepthwiseMaxPool(OpKernelContext* context, Tensor* output, + const Tensor& tensor_in, const PoolParameters& params) { + Eigen::Map> + in_by_pool(tensor_in.flat().data(), params.depth_window, + tensor_in.NumElements() / params.depth_window); + Eigen::Map> out_by_pool( + output->flat().data(), 1, output->NumElements()); + out_by_pool = in_by_pool.colwise().maxCoeff(); + } + + void SpatialMaxPool(OpKernelContext* context, Tensor* output, + const Tensor& tensor_in, const PoolParameters& params, + const Padding& padding) { + // On GPU, use Eigen's Spatial Max Pooling. On CPU, use an + // EigenMatrix version that is currently faster than Eigen's + // Spatial MaxPooling implementation. + // + // TODO(vrv): Remove this once we no longer need it. + if (std::is_same::value) { + Eigen::PaddingType pt = BrainPadding2EigenPadding(padding); + functor::SpatialMaxPooling()( + context->eigen_device(), output->tensor(), + tensor_in.tensor(), params.window_rows, params.window_cols, + params.row_stride, params.col_stride, pt); + } else { + typedef Eigen::Map> + ConstEigenMatrixMap; + typedef Eigen::Map> + EigenMatrixMap; + + ConstEigenMatrixMap in_mat(tensor_in.flat().data(), params.depth, + params.tensor_in_cols * params.tensor_in_rows * + params.tensor_in_batch); + EigenMatrixMap out_mat( + output->flat().data(), params.depth, + params.out_width * params.out_height * params.tensor_in_batch); + + // Initializes the output tensor with MIN. + output->flat().setConstant(Eigen::NumTraits::lowest()); + + // The following code basically does the following: + // 1. Flattens the input and output tensors into two dimensional arrays. + // tensor_in_as_matrix: + // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch) + // output_as_matrix: + // depth by (out_width * out_height * tensor_in_batch) + // + // 2. Walks through the set of columns in the flattened + // tensor_in_as_matrix, + // and updates the corresponding column(s) in output_as_matrix with the + // max value. 
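// For a single row index (illustrative numbers): with window_rows = 3,
// row_stride = 1 and pad_rows = 1, input row h = 4 has hpad = 5, so
// h_start = (5 - 3) / 1 + 1 = 3 and h_end = min(5 / 1 + 1, out_height) = 6
// (capped by out_height). Row 4 therefore contributes to output rows 3..5,
// exactly the windows whose 3-row span covers it.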
+ for (int b = 0; b < params.tensor_in_batch; ++b) { + for (int h = 0; h < params.tensor_in_rows; ++h) { + for (int w = 0; w < params.tensor_in_cols; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int hpad = h + params.pad_rows; + const int wpad = w + params.pad_cols; + const int h_start = + (hpad < params.window_rows) + ? 0 + : (hpad - params.window_rows) / params.row_stride + 1; + const int h_end = + std::min(hpad / params.row_stride + 1, params.out_height); + const int w_start = + (wpad < params.window_cols) + ? 0 + : (wpad - params.window_cols) / params.col_stride + 1; + const int w_end = + std::min(wpad / params.col_stride + 1, params.out_width); + // compute elementwise max + const int in_offset = + (b * params.tensor_in_rows + h) * params.tensor_in_cols + w; + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + const int out_offset = + (b * params.out_height + ph) * params.out_width + pw; + out_mat.col(out_offset) = + out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset)); + } + } + } + } + } + } + } + + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +template +void SpatialAvgPool(OpKernelContext* context, Tensor* output, + const Tensor& input, const PoolParameters& params, + const Padding& padding) { + typedef Eigen::Map> + ConstEigenMatrixMap; + typedef Eigen::Map> + EigenMatrixMap; + + auto in_flat = input.flat(); + auto out_flat = output->flat(); + + ConstEigenMatrixMap in_mat( + in_flat.data(), params.depth, + params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch); + EigenMatrixMap out_mat( + out_flat.data(), params.depth, + params.out_width * params.out_height * params.tensor_in_batch); + Eigen::Matrix out_count(out_mat.cols()); + out_count.setZero(); + + // Initializes output to zero. + out_flat.setZero(); + + // The following code basically does the following: + // 1. Flattens the input and output tensors into two dimensional arrays. + // tensor_in_as_matrix: + // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch) + // output_as_matrix: + // depth by (out_width * out_height * tensor_in_batch) + // + // 2. Walks through the set of columns in the flattened + // tensor_in_as_matrix, + // and updates the corresponding column(s) in output_as_matrix with the + // average value. + for (int b = 0; b < params.tensor_in_batch; ++b) { + for (int h = 0; h < params.tensor_in_rows; ++h) { + for (int w = 0; w < params.tensor_in_cols; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int hpad = h + params.pad_rows; + const int wpad = w + params.pad_cols; + const int h_start = + (hpad < params.window_rows) + ? 0 + : (hpad - params.window_rows) / params.row_stride + 1; + const int h_end = + std::min(hpad / params.row_stride + 1, params.out_height); + const int w_start = + (wpad < params.window_cols) + ? 
0 + : (wpad - params.window_cols) / params.col_stride + 1; + const int w_end = + std::min(wpad / params.col_stride + 1, params.out_width); + const int in_offset = + (b * params.tensor_in_rows + h) * params.tensor_in_cols + w; + Eigen::DSizes in_indices(0, in_offset); + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + const int out_offset = + (b * params.out_height + ph) * params.out_width + pw; + out_mat.col(out_offset) += in_mat.col(in_offset); + out_count(out_offset)++; + } + } + } + } + } + DCHECK_GT(out_count.minCoeff(), 0); + out_mat.array().rowwise() /= out_count.transpose().array(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_ diff --git a/tensorflow/core/kernels/pooling_ops_common_gpu.h b/tensorflow/core/kernels/pooling_ops_common_gpu.h new file mode 100644 index 0000000000..87a3ef5186 --- /dev/null +++ b/tensorflow/core/kernels/pooling_ops_common_gpu.h @@ -0,0 +1,39 @@ +#if !GOOGLE_CUDA +#error This file must only be included when building with Cuda support +#endif + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_ + +#include "tensorflow/stream_executor/dnn.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/avgpooling_op.h" +#include "tensorflow/core/kernels/maxpooling_op.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +// A helper class that launch the cudnn pooling backward operations. +// The original input and output tensors are optional for AvgPoolGrad, but +// mandatory for MaxPoolGrad. +template +class DnnPoolingGradOp { + public: + typedef GPUDevice Device; + static void Compute(OpKernelContext* context, + perftools::gputools::dnn::PoolingMode pooling_mode, + const std::vector& size, + const std::vector& stride, Padding padding, + const Tensor* tensor_in, const Tensor* tensor_out, + const Tensor& out_backprop, + const TensorShape& tensor_in_shape); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_ diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc new file mode 100644 index 0000000000..1b13f68a3a --- /dev/null +++ b/tensorflow/core/kernels/queue_base.cc @@ -0,0 +1,153 @@ +#include "tensorflow/core/kernels/queue_base.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +namespace { + +template +void HandleSliceToElement(const Tensor& parent, Tensor* element, int index) { + typedef typename EnumToDataType
<DT>::Type T; + auto parent_as_matrix = parent.flat_outer_dims<T>(); + element->flat<T>() = parent_as_matrix.chip(index, 0); +} + +template <DataType DT> +void HandleElementToSlice(const Tensor& element, Tensor* parent, int index) { + typedef typename EnumToDataType
<DT>::Type T; + auto parent_as_matrix = parent->flat_outer_dims<T>(); + parent_as_matrix.chip(index, 0) = element.flat<T>(); +} + +} // namespace + +// static +Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, + int index) { +#define HANDLE_TYPE(DT) \ + if (parent.dtype() == DT) { \ + HandleSliceToElement<DT>
(parent, element, index); \ + return Status::OK(); \ + } + HANDLE_TYPE(DT_FLOAT); + HANDLE_TYPE(DT_DOUBLE); + HANDLE_TYPE(DT_INT32); + HANDLE_TYPE(DT_UINT8); + HANDLE_TYPE(DT_INT16); + HANDLE_TYPE(DT_INT8); + HANDLE_TYPE(DT_STRING); + HANDLE_TYPE(DT_INT64); +#undef HANDLE_TYPE + return errors::Unimplemented("Unhandled data type: ", parent.dtype()); +} + +// static +Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent, + int index) { +#define HANDLE_TYPE(DT) \ + if (element.dtype() == DT) { \ + HandleElementToSlice<DT>
(element, parent, index); \ + return Status::OK(); \ + } + HANDLE_TYPE(DT_FLOAT); + HANDLE_TYPE(DT_DOUBLE); + HANDLE_TYPE(DT_INT32); + HANDLE_TYPE(DT_UINT8); + HANDLE_TYPE(DT_INT16); + HANDLE_TYPE(DT_INT8); + HANDLE_TYPE(DT_STRING); + HANDLE_TYPE(DT_INT64); +#undef HANDLE_TYPE + return errors::Unimplemented("Unhandled data type: ", element.dtype()); +} + +QueueBase::QueueBase(const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name) + : component_dtypes_(component_dtypes), + component_shapes_(component_shapes), + name_(name) {} + +Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const { + if (tuple.size() != static_cast(num_components())) { + return errors::InvalidArgument( + "Wrong number of components in tuple. Expected ", num_components(), + ", got ", tuple.size()); + } + for (size_t i = 0; i < tuple.size(); ++i) { + if (tuple[i].dtype() != component_dtypes_[i]) { + return errors::InvalidArgument( + "Type mismatch in tuple component ", i, ". Expected ", + DataTypeString(component_dtypes_[i]), ", got ", + DataTypeString(tuple[i].dtype())); + } + } + return Status::OK(); +} + +// static +string QueueBase::ShapeListString(const gtl::ArraySlice& shapes) { + string result = "["; + bool first = true; + for (const TensorShape& shape : shapes) { + strings::StrAppend(&result, (first ? "" : ", "), shape.ShortDebugString()); + first = false; + } + strings::StrAppend(&result, "]"); + return result; +} + +Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def, + const string& op) const { + if (node_def.op() != op) { + return errors::InvalidArgument("Shared queue '", name_, "' has type '", op, + "' that does not match type of Node '", + node_def.name(), "': ", node_def.op()); + } + return Status::OK(); +} + +Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def, + int32 capacity) const { + int32 requested_capacity = -1; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "capacity", &requested_capacity)); + if (requested_capacity < 0) requested_capacity = kUnbounded; + if (requested_capacity != capacity) { + return errors::InvalidArgument("Shared queue '", name_, "' has capacity ", + capacity, " but requested capacity was ", + requested_capacity); + } + return Status::OK(); +} + +Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const { + DataTypeVector requested_dtypes; + TF_RETURN_IF_ERROR( + GetNodeAttr(node_def, "component_types", &requested_dtypes)); + if (requested_dtypes != component_dtypes_) { + return errors::InvalidArgument("Shared queue '", name_, + "' has component types ", + DataTypeSliceString(component_dtypes_), + " but requested component types were ", + DataTypeSliceString(requested_dtypes)); + } + return Status::OK(); +} + +Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const { + std::vector requested_shapes; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes)); + if (requested_shapes != component_shapes_) { + return errors::InvalidArgument("Shared queue '", name_, + "' has component shapes ", + ShapeListString(component_shapes_), + " but requested component shapes were ", + ShapeListString(requested_shapes)); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h new file mode 100644 index 0000000000..4897102974 --- /dev/null +++ b/tensorflow/core/kernels/queue_base.h @@ -0,0 +1,77 @@ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ +#define 
THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +// Functionality common to QueueInterface implementations. +class QueueBase : public QueueInterface { + public: + // As a possible value of 'capacity'. + static const int32 kUnbounded = INT_MAX; + + // Args: + // component_dtypes: The types of each component in a queue-element tuple. + // component_shapes: The shapes of each component in a queue-element tuple, + // which must either be empty (if the shapes are not specified) or + // or have the same size as component_dtypes. + // name: A name to use for the queue. + QueueBase(const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name); + + // Implementations of QueueInterface methods -------------------------------- + const DataTypeVector& component_dtypes() const override { + return component_dtypes_; + } + + // Other public methods ----------------------------------------------------- + const std::vector& component_shapes() const { + return component_shapes_; + } + + protected: + // Returns the number of components in a queue-element tuple. + int32 num_components() const { return component_dtypes_.size(); } + + // True if shapes were specified. If so, inputs will be validated + // against them, etc. + bool specified_shapes() const { return component_shapes_.size() > 0; } + + // Code common to Validate*Tuple(). + Status ValidateTupleCommon(const Tuple& tuple) const; + + // Copies the index^th slice (in the first dimension) of parent into element. + static Status CopySliceToElement(const Tensor& parent, Tensor* element, + int index); + + // Copies element into the index^th slice (in the first dimension) of parent. + static Status CopyElementToSlice(const Tensor& element, Tensor* parent, + int index); + + ~QueueBase() override {} + + // Helpers for implementing MatchesNodeDef(). + static string ShapeListString(const gtl::ArraySlice& shapes); + Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const; + Status MatchesNodeDefCapacity(const NodeDef& node_def, int32 capacity) const; + Status MatchesNodeDefTypes(const NodeDef& node_def) const; + Status MatchesNodeDefShapes(const NodeDef& node_def) const; + + const DataTypeVector component_dtypes_; + const std::vector component_shapes_; + const string name_; + + TF_DISALLOW_COPY_AND_ASSIGN(QueueBase); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc new file mode 100644 index 0000000000..c70dc76777 --- /dev/null +++ b/tensorflow/core/kernels/queue_ops.cc @@ -0,0 +1,288 @@ +// See docs in ../ops/data_flow_ops.cc. 
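// A minimal sketch (not part of this commit) of the slice-copy semantics that
// QueueBase::CopySliceToElement and CopyElementToSlice provide above: "parent"
// is a batched tensor of shape [batch, d1, ..., dn] and "element" is a single
// tuple component of shape [d1, ..., dn], so copying slice `index` is a row
// copy over the flattened inner dimensions. The helper name and the use of an
// Eigen flat view here are illustrative assumptions, not the committed code.
template <typename T>
void CopySliceToElementSketch(const Tensor& parent, Tensor* element,
                              int index) {
  auto parent_mat = parent.flat_outer_dims<T>();  // shape [batch, inner]
  auto element_vec = element->flat<T>();          // shape [inner]
  for (int64 j = 0; j < element_vec.size(); ++j) {
    element_vec(j) = parent_mat(index, j);
  }
}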
+ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class QueueOpKernel : public AsyncOpKernel { + public: + explicit QueueOpKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { + QueueInterface* queue; + OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), + callback); + ComputeAsync(ctx, queue, [callback, queue]() { + queue->Unref(); + callback(); + }); + } + + protected: + virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) = 0; +}; + +class QueueAccessOpKernel : public QueueOpKernel { + public: + explicit QueueAccessOpKernel(OpKernelConstruction* context) + : QueueOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_)); + // TODO(keveman): Enable timeout. + OP_REQUIRES(context, timeout_ == -1, + errors::InvalidArgument("Timeout not supported yet.")); + } + + protected: + int64 timeout_; +}; + +// Defines an EnqueueOp, the execution of which enqueues a tuple of +// tensors in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +class EnqueueOp : public QueueAccessOpKernel { + public: + explicit EnqueueOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + DataTypeVector expected_inputs = {DT_STRING_REF}; + for (DataType dt : queue->component_dtypes()) { + expected_inputs.push_back(dt); + } + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), + callback); + + QueueInterface::Tuple tuple; + OpInputList components; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), + callback); + for (const Tensor& Tcomponent : components) { + tuple.push_back(Tcomponent); + } + + OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback); + queue->TryEnqueue(tuple, ctx, callback); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp); + +// Defines an EnqueueManyOp, the execution of which slices each +// component of a tuple of tensors along the 0th dimension, and +// enqueues tuples of slices in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +// +// N.B. All tuple components must have the same size in the 0th +// dimension. 
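// Illustration (assumption, not committed code): for a queue whose component
// shapes are {[2], [3, 4]}, a valid EnqueueMany input is a tuple of tensors
// shaped {[n, 2], [n, 3, 4]}; the kernel below slices each component along
// dimension 0 and enqueues n separate {[2], [3, 4]} tuples.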
+class EnqueueManyOp : public QueueAccessOpKernel { + public: + explicit EnqueueManyOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + DataTypeVector expected_inputs = {DT_STRING_REF}; + for (DataType dt : queue->component_dtypes()) { + expected_inputs.push_back(dt); + } + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + + QueueInterface::Tuple tuple; + OpInputList components; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), + callback); + for (const Tensor& Tcomponent : components) { + tuple.push_back(Tcomponent); + } + + OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback); + queue->TryEnqueueMany(tuple, ctx, callback); + } + + ~EnqueueManyOp() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU), + EnqueueManyOp); + +// Defines a DequeueOp, the execution of which dequeues a tuple of +// tensors from the given Queue. +// +// The op has one input, which is the handle of the appropriate +// Queue. The op has k outputs, where k is the number of components in +// the tuples stored in the given Queue, and output i is the ith +// component of the dequeued tuple. +class DequeueOp : public QueueAccessOpKernel { + public: + explicit DequeueOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), + callback); + + queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); + } + + ~DequeueOp() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp); + +// Defines a DequeueManyOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. 
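// Illustration (assumption, not committed code): DequeueMany is the inverse of
// EnqueueMany above. Dequeuing n elements from a queue with component shapes
// {[2], [3, 4]} produces a single output tuple shaped {[n, 2], [n, 3, 4]},
// with the dequeued elements stacked along a new 0th dimension.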
+class DequeueManyOp : public QueueAccessOpKernel { + public: + explicit DequeueManyOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + const Tensor& Tnum_elements = ctx->input(1); + int32 num_elements = Tnum_elements.flat()(0); + + OP_REQUIRES_ASYNC( + ctx, num_elements >= 0, + errors::InvalidArgument("DequeueManyOp must request a positive number " + "of elements"), + callback); + + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + + queue->TryDequeueMany( + num_elements, ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), + callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); + } + + ~DequeueManyOp() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU), + DequeueManyOp); + +// Defines a QueueCloseOp, which closes the given Queue. Closing a +// Queue signals that no more elements will be enqueued in it. +// +// The op has one input, which is the handle of the appropriate Queue. +class QueueCloseOp : public QueueOpKernel { + public: + explicit QueueCloseOp(OpKernelConstruction* context) + : QueueOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues", + &cancel_pending_enqueues_)); + } + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + queue->Close(ctx, cancel_pending_enqueues_, callback); + } + + private: + bool cancel_pending_enqueues_; + TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp); + +// Defines a QueueSizeOp, which computes the number of elements in the +// given Queue, and emits it as an output tensor. +// +// The op has one input, which is the handle of the appropriate Queue; +// and one output, which is a single-element tensor containing the current +// size of that Queue. +class QueueSizeOp : public QueueOpKernel { + public: + explicit QueueSizeOp(OpKernelConstruction* context) + : QueueOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + Tensor* Tqueue_size = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size)); + Tqueue_size->flat().setConstant(queue->size()); + callback(); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc new file mode 100644 index 0000000000..4fc12e92cb --- /dev/null +++ b/tensorflow/core/kernels/random_crop_op.cc @@ -0,0 +1,103 @@ +// See docs in ../ops/image_ops.cc. 
+ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/guarded_philox_random.h" + +namespace tensorflow { + +template +class RandomCropOp : public OpKernel { + public: + explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, generator_.Init(context)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 3, + errors::InvalidArgument("input must be 3-dimensional", + input.shape().ShortDebugString())); + const Tensor& shape_t = context->input(1); + OP_REQUIRES(context, shape_t.dims() == 1, + errors::InvalidArgument("shape_t must be 1-dimensional", + shape_t.shape().ShortDebugString())); + OP_REQUIRES(context, shape_t.NumElements() == 2, + errors::InvalidArgument("shape_t must have two elements", + shape_t.shape().ShortDebugString())); + + auto shape_vec = shape_t.vec(); + const int32 target_height = shape_vec(0); + const int32 target_width = shape_vec(1); + + const int32 height = input.dim_size(0); + const int32 width = input.dim_size(1); + const int32 channels = input.dim_size(2); + + // Initialize shape to the batch size of the input, then add + // the rest of the dimensions + Tensor* output = nullptr; + const auto output_shape = + TensorShape({target_height, target_width, channels}); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + // If the target size matches the actual size, then do nothing. + if ((target_height == height) && (target_width == width)) { + *output = context->input(0); + } + + // TODO(shlens): Implement edge case to guarantee output size dimensions. + // Edge case. The target dimensions are larger then the image, so + // zero-pad the image. This guarantees that the image will *always* + // be [target_height, target_width] in size. + OP_REQUIRES(context, width >= target_width, errors::FailedPrecondition( + "width must be >= target_width: width = ", width, + ", target_width = ", target_width)); + OP_REQUIRES(context, height >= target_height, errors::FailedPrecondition( + "height must be >= target_height: height = ", height, + ", target_height = ", target_height)); + + int32 offset_height = 0; + int32 offset_width = 0; + + auto local_gen = generator_.ReserveSamples32(2); + random::SimplePhilox random(&local_gen); + + if (width > target_width) { + offset_width = random.Rand32() % (width - target_width + 1); + } + if (height > target_height) { + offset_height = random.Rand32() % (height - target_height + 1); + } + + // TODO(shlens): Do this more efficiently with memcpy once padding is + // available for smaller images. 
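// Descriptive note (assumption, not committed code): at this point
// offset_height is uniform in [0, height - target_height] and offset_width is
// uniform in [0, width - target_width], so the loop below copies the window
//   input[offset_height : offset_height + target_height,
//         offset_width  : offset_width  + target_width, :]
// element by element into the [target_height, target_width, channels] output.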
+ typename TTypes::ConstTensor input_data = input.tensor(); + typename TTypes::Tensor output_data = output->tensor(); + + for (int y = 0; y < target_height; ++y) { + for (int x = 0; x < target_width; ++x) { + for (int c = 0; c < channels; ++c) { + output_data(y, x, c) = + input_data(y + offset_height, x + offset_width, c); + } + } + } + } + + private: + GuardedPhiloxRandom generator_; +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomCrop").Device(DEVICE_CPU).TypeConstraint("T"), \ + RandomCropOp) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/random_crop_op_test.cc b/tensorflow/core/kernels/random_crop_op_test.cc new file mode 100644 index 0000000000..1f232f4969 --- /dev/null +++ b/tensorflow/core/kernels/random_crop_op_test.cc @@ -0,0 +1,60 @@ +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { + +class RandomCropOpTest : public OpsTestBase { + protected: + RandomCropOpTest() { + RequireDefaultOps(); + EXPECT_OK(NodeDefBuilder("random_crop_op", "RandomCrop") + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_INT64)) + .Attr("T", DT_UINT8) + .Finalize(node_def())); + EXPECT_OK(InitOp()); + } +}; + +TEST_F(RandomCropOpTest, Basic) { + AddInputFromArray(TensorShape({1, 2, 1}), {2, 2}); + AddInputFromArray(TensorShape({2}), {1, 1}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_UINT8, TensorShape({1, 1, 1})); + test::FillValues(&expected, {2}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RandomCropOpTest, SameSizeOneChannel) { + AddInputFromArray(TensorShape({2, 1, 1}), {1, 2}); + AddInputFromArray(TensorShape({2}), {2, 1}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_UINT8, TensorShape({2, 1, 1})); + test::FillValues(&expected, {1, 2}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RandomCropOpTest, SameSizeMultiChannel) { + AddInputFromArray(TensorShape({2, 1, 3}), {1, 2, 3, 4, 5, 6}); + AddInputFromArray(TensorShape({2}), {2, 1}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_UINT8, TensorShape({2, 1, 3})); + test::FillValues(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc new file mode 100644 index 0000000000..09b66d30e6 --- /dev/null +++ b/tensorflow/core/kernels/random_op.cc @@ -0,0 +1,276 @@ +// See docs in ../ops/random_ops.cc. 
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/random_op.h" + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +// The default implementation of the functor, which should never be invoked +// But we still need to provide implementation for now for the linker to work, +// since we do not support all the distributions yet. +template +struct FillPhiloxRandom { + typedef typename Distribution::ResultElementType T; + void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen, + T* data, int64 size) { + LOG(FATAL) << "Default FillPhiloxRandom should not be executed."; + } +}; + +#if GOOGLE_CUDA +// Declaration for the partial specialization with GPU +template +struct FillPhiloxRandom { + typedef typename Distribution::ResultElementType T; + void operator()(OpKernelContext* ctx, const GPUDevice&, + random::PhiloxRandom gen, T* data, int64 size); +}; + +#endif + +// A class to fill a specified range of random groups +template +struct FillPhiloxRandomTask; + +// Specialization for distribution that takes a fixed number of samples for +// each output. +template +struct FillPhiloxRandomTask { + typedef typename Distribution::ResultElementType T; + static void Run(random::PhiloxRandom gen, T* data, int64 size, + int64 start_group, int64 limit_group) { + Distribution dist; + const int kGroupSize = Distribution::kResultElementCount; + + gen.Skip(start_group); + int64 offset = start_group * kGroupSize; + + // First fill all the full-size groups + int64 limit_group_full = std::min(limit_group, size / kGroupSize); + for (int64 index = start_group; index < limit_group_full; ++index) { + auto samples = dist(&gen); + std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); + offset += kGroupSize; + } + + // If there are any remaining elements that need to be filled, process them + if (limit_group_full < limit_group) { + int remaining_size = size - limit_group_full * kGroupSize; + auto samples = dist(&gen); + std::copy(&samples[0], &samples[0] + remaining_size, data + offset); + } + } +}; + +// Specialization for distribution that takes a varaiable number of samples for +// each output. This will be slower due to the generality. 
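// Sketch (assumption, not committed code) of the property the sharded fill
// above relies on: PhiloxRandom is counter-based, so Skip(k) yields the same
// stream position as generating and discarding k result groups. Generators
// with equal seeds that skip to the same group therefore produce identical
// samples, which makes the output independent of how the work is split.
void PhiloxSkipSketch() {
  random::PhiloxRandom a(/*seed_lo=*/17, /*seed_hi=*/43);
  random::PhiloxRandom b = a;
  a.Skip(100);                        // jump directly to group 100
  for (int i = 0; i < 100; ++i) b();  // or generate and discard 100 groups
  auto sa = a();                      // next 128-bit sample group
  auto sb = b();                      // identical to sa
  CHECK(sa[0] == sb[0] && sa[1] == sb[1] && sa[2] == sb[2] && sa[3] == sb[3]);
}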
+template +struct FillPhiloxRandomTask { + typedef typename Distribution::ResultElementType T; + static const int64 kReservedSamplesPerOutput = 256; + + static void Run(random::PhiloxRandom base_gen, T* data, int64 size, + int64 start_group, int64 limit_group) { + using random::PhiloxRandom; + using random::SingleSampleAdapter; + + Distribution dist; + const int kGroupSize = Distribution::kResultElementCount; + + static const int kGeneratorSkipPerOutputGroup = + kGroupSize * kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + int64 offset = start_group * kGroupSize; + + // First fill all the full-size groups + int64 limit_group_full = std::min(limit_group, size / kGroupSize); + int64 group_index; + for (group_index = start_group; group_index < limit_group_full; + ++group_index) { + // Reset the generator to the beginning of the output group region + // This is necessary if we want the results to be independent of order + // of work + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + auto samples = dist(&single_samples); + std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); + offset += kGroupSize; + } + + // If there are any remaining elements that need to be filled, process them + if (limit_group_full < limit_group) { + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + int remaining_size = size - limit_group_full * kGroupSize; + auto samples = dist(&single_samples); + std::copy(&samples[0], &samples[0] + remaining_size, data + offset); + } + } +}; + +// Partial specialization for CPU to fill the entire region with randoms +// It splits the work into several tasks and run them in parallel +template +struct FillPhiloxRandom { + typedef typename Distribution::ResultElementType T; + void operator()(OpKernelContext* context, const CPUDevice&, + random::PhiloxRandom gen, T* data, int64 size) { + const int kGroupSize = Distribution::kResultElementCount; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + int64 total_group_count = (size + kGroupSize - 1) / kGroupSize; + + // Limit to maximum six threads for now. The performance scaling is very + // sub-linear. Too many threads causes a much worse overall performance. + int num_workers = 6; + Shard(num_workers, worker_threads.workers, total_group_count, kGroupSize, + [&gen, data, size](int64 start_group, int64 limit_group) { + FillPhiloxRandomTask< + Distribution, + Distribution::kVariableSamplesPerOutput>::Run(gen, data, size, + start_group, + limit_group); + }); + } +}; +} // namespace functor + +// For now, use the same interface as RandomOp, so we can choose either one +// at the run-time. 
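// Rough serial equivalent (assumption, not committed code) of the Shard()
// call in the CPU FillPhiloxRandom specialization above: the group range is
// cut into disjoint [start_group, limit_group) chunks and each chunk is
// filled independently. Because every chunk skips the shared counter-based
// generator to its own starting group, the result matches a single-threaded
// fill regardless of how the chunks are ordered.
template <class Distribution>
void SerialFillSketch(random::PhiloxRandom gen,
                      typename Distribution::ResultElementType* data,
                      int64 size, int64 chunk_groups) {
  const int kGroupSize = Distribution::kResultElementCount;
  const int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
  for (int64 start = 0; start < total_group_count; start += chunk_groups) {
    const int64 limit = std::min(start + chunk_groups, total_group_count);
    functor::FillPhiloxRandomTask<
        Distribution, Distribution::kVariableSamplesPerOutput>::Run(
        gen, data, size, start, limit);
  }
}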
+template +class PhiloxRandomOp : public OpKernel { + public: + typedef typename Distribution::ResultElementType T; + explicit PhiloxRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, generator_.Init(ctx)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& input = ctx->input(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsLegacyVector(input.shape()), + errors::InvalidArgument("shape must be a vector of {int32,int64}.")); + Tensor* output = nullptr; + if (input.dtype() == DataType::DT_INT32) { + auto vec = input.flat(); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShapeUtils::MakeShape( + vec.data(), vec.size()), + &output)); + } else if (input.dtype() == DataType::DT_INT64) { + auto vec = input.flat(); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShapeUtils::MakeShape( + vec.data(), vec.size()), + &output)); + } else { + OP_REQUIRES(ctx, false, errors::InvalidArgument( + "shape must be a vector of {int32,int64}.")); + } + functor::FillPhiloxRandom()( + ctx, ctx->eigen_device(), + ReserveRandomOutputs(output->flat().size()), + output->flat().data(), output->flat().size()); + } + + private: + GuardedPhiloxRandom generator_; + + // Reserve enough random samples in the generator for the given output count. + random::PhiloxRandom ReserveRandomOutputs(int64 output_count) { + int64 conservative_sample_count = output_count << 8; + return generator_.ReserveSamples128(conservative_sample_count); + } +}; + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomUniform") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp >); \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomStandardNormal") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp >); \ + REGISTER_KERNEL_BUILDER( \ + Name("TruncatedNormal") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp< \ + CPUDevice, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE> >) + +REGISTER(float); +REGISTER(double); + +#undef REGISTER + +#if GOOGLE_CUDA + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomUniform") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .TypeConstraint("T") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp >); \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomStandardNormal") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .TypeConstraint("T") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp >); \ + REGISTER_KERNEL_BUILDER( \ + Name("TruncatedNormal") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .TypeConstraint("T") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp< \ + GPUDevice, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE> >) + +REGISTER(float); +REGISTER(double); + +#undef REGISTER + +#endif // GOOGLE_CUDA + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/random_op.h b/tensorflow/core/kernels/random_op.h new file mode 100644 index 0000000000..7c7eed4227 --- /dev/null +++ b/tensorflow/core/kernels/random_op.h @@ -0,0 +1,16 @@ +#ifndef TENSORFLOW_KERNELS_RANDOM_OP_H_ +#define TENSORFLOW_KERNELS_RANDOM_OP_H_ + +namespace tensorflow { + +class OpKernelContext; + +namespace functor { + +template +struct FillPhiloxRandom; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_RANDOM_OP_H_ diff --git a/tensorflow/core/kernels/random_op_gpu.cu.cc b/tensorflow/core/kernels/random_op_gpu.cu.cc new 
file mode 100644 index 0000000000..15cf85f27e --- /dev/null +++ b/tensorflow/core/kernels/random_op_gpu.cu.cc @@ -0,0 +1,152 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/random_op.h" + +#include +#include + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +class OpKernelContext; + +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +template +struct FillPhiloxRandomKernel; + +// A cuda kernel to fill the data with random numbers from the specified +// distribution. Each output takes a fixed number of samples. +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size) { + Distribution dist; + const int kGroupSize = Distribution::kResultElementCount; + + const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int32 total_thread_count = gridDim.x * blockDim.x; + int32 offset = thread_id * kGroupSize; + gen.Skip(thread_id); + + while (offset < size) { + typename Distribution::ResultType samples = dist(&gen); + + for (int i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + + offset += (total_thread_count - 1) * kGroupSize; + gen.Skip(total_thread_count - 1); + } + } +}; + +// A cuda kernel to fill the data with random numbers from the specified +// distribution. Each output takes a variable number of samples. +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data, + int64 size) { + using random::PhiloxRandom; + using random::SingleSampleAdapter; + + const int kReservedSamplesPerOutput = 256; + const int kGroupSize = Distribution::kResultElementCount; + const int kGeneratorSkipPerOutputGroup = kGroupSize * + kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int32 total_thread_count = gridDim.x * blockDim.x; + int64 group_index = thread_id; + int64 offset = group_index * kGroupSize; + Distribution dist; + + while (offset < size) { + // Since each output takes a variable number of samples, we need to + // realign the generator to the beginning for the current output group + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + typename Distribution::ResultType samples = dist(&single_samples); + + for (int i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + + offset += (total_thread_count - 1) * kGroupSize; + group_index += total_thread_count; + } + } +}; + +// A simple launch pad to call the correct function templates to fill the data +template +__global__ void __launch_bounds__(1024) + FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen, + typename Distribution::ResultElementType* data, + int64 size) { + FillPhiloxRandomKernel() + .Run(base_gen, data, size); +} + +// Partial specialization for GPU +template +struct FillPhiloxRandom { + typedef typename Distribution::ResultElementType T; + typedef GPUDevice Device; + void operator()(OpKernelContext*, const Device& d, random::PhiloxRandom gen, + T* data, int64 size) { + const int32 block_size = 
d.maxCudaThreadsPerBlock(); + const int32 num_blocks = + (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) / + block_size; + + FillPhiloxRandomKernelLaunch< + Distribution><<>>(gen, data, + size); + } +}; + +// Explicit instantiation of the GPU distributions functors +// clang-format off +// NVCC cannot handle ">>" properly +template struct FillPhiloxRandom< + GPUDevice, random::UniformDistribution >; +template struct FillPhiloxRandom< + GPUDevice, random::UniformDistribution >; +template struct FillPhiloxRandom< + GPUDevice, random::NormalDistribution >; +template struct FillPhiloxRandom< + GPUDevice, random::NormalDistribution >; +template struct FillPhiloxRandom< + GPUDevice, random::TruncatedNormalDistribution< + random::SingleSampleAdapter, float> >; +template struct FillPhiloxRandom< + GPUDevice, random::TruncatedNormalDistribution< + random::SingleSampleAdapter, double> >; +// clang-format on + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/random_op_test.cc b/tensorflow/core/kernels/random_op_test.cc new file mode 100644 index 0000000000..751b61cfba --- /dev/null +++ b/tensorflow/core/kernels/random_op_test.cc @@ -0,0 +1,99 @@ +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" +#include + +namespace tensorflow { + +Tensor Int32(int32 v) { + Tensor t(DT_INT32, TensorShape({})); + t.scalar()() = v; + return t; +} + +Graph* RandomUniform(int64 n) { + Graph* g = new Graph(OpRegistry::Global()); + test::graph::RandomUniform(g, test::graph::Constant(g, Int32(n)), DT_FLOAT); + return g; +} + +Graph* RandomNormal(int64 n) { + Graph* g = new Graph(OpRegistry::Global()); + test::graph::RandomGaussian(g, test::graph::Constant(g, Int32(n)), DT_FLOAT); + return g; +} + +Graph* RandomParameters(int64 n) { + Graph* g = new Graph(OpRegistry::Global()); + test::graph::RandomParameters(g, test::graph::Constant(g, Int32(n)), + DT_FLOAT); + return g; +} + +#define BM_RNG(DEVICE, RNG) \ + static void BM_##DEVICE##_##RNG(int iters, int arg) { \ + testing::ItemsProcessed(static_cast(iters) * arg); \ + test::Benchmark(#DEVICE, RNG(arg)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20); + +BM_RNG(cpu, RandomUniform); +BM_RNG(cpu, RandomNormal); +BM_RNG(cpu, RandomParameters); + +BM_RNG(gpu, RandomUniform); +BM_RNG(gpu, RandomNormal); +BM_RNG(gpu, RandomParameters); + +static void BM_PhiloxRandom(int iters) { + // Fill 2M random numbers + int count = 2 << 20; + + testing::ItemsProcessed(static_cast(iters) * count); + + random::PhiloxRandom gen(0x12345); + + int val = 1; + for (int i = 0; i < iters; ++i) { + for (int j = 0; j < count; j += 4) { + /// each invocation of gen() returns 128-bit samples + auto samples = gen(); + + // use the result trivially so the compiler does not optimize it away + val ^= samples[0] ^ samples[1] ^ samples[2] ^ samples[3]; + } + } + + // A anchor point to make sure the compiler does not cut corners + CHECK(val) << val; +} +BENCHMARK(BM_PhiloxRandom); + +static void BM_StdMTRandom(int iters) { + // Fill 2M random numbers + int count = 2 << 20; + + testing::ItemsProcessed(static_cast(iters) * count); + + std::mt19937 gen(0x12345); + + int val = 1; + for (int i = 0; i < iters; ++i) { + for (int j = 0; j < count; ++j) { + /// each invocation of gen() returns 32-bit sample + uint32 sample 
= gen(); + + // use the result trivially so the compiler does not optimize it away + val ^= sample; + } + } + + // A anchor point to make sure the compiler does not cut corners + CHECK(val) << val; +} +BENCHMARK(BM_StdMTRandom); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/random_shuffle_op.cc b/tensorflow/core/kernels/random_shuffle_op.cc new file mode 100644 index 0000000000..b87f4e58a0 --- /dev/null +++ b/tensorflow/core/kernels/random_shuffle_op.cc @@ -0,0 +1,89 @@ +// See docs in ../ops/random_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/util/guarded_philox_random.h" + +namespace tensorflow { + +// TODO(irving): If performance is critical, generate output directly instead +// of an in-place shuffle using a pseudorandom permutation like +// +// https://github.com/otherlab/geode/blob/master/geode/random/permute.cpp +// +// This is probably also the right thing if we want a GPU version of shuffling. + +// We use our own version of std::random_shuffle to guarantee that exactly +// size - 1 samples are used. +template +static inline void RandomShuffle(Iter first, Iter last, Random& uniform) { + if (first == last) return; + const auto stop = last - 1; + for (auto i = first; i != stop; ++i) { + using std::iter_swap; + iter_swap(i, i + uniform(last - i)); + } +} + +template +class RandomShuffleOp : public OpKernel { + public: + explicit RandomShuffleOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, generator_.Init(context)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + + if (input.NumElements() <= 1 || input.dim_size(0) <= 1) { + // No shuffling is required, so copy input directly to output + context->set_output(0, input); + } else { + // Reserve enough random samples for shuffling + const int64 size = input.dim_size(0); + const int64 samples = size - 1; + auto local_gen = generator_.ReserveSamples32(samples); + random::SingleSampleAdapter single(&local_gen); + const auto uniform = [&single](uint32 n) { return single() % n; }; + + if (input.dims() == 1) { + // For 1D data, copy and then shuffle in place + context->set_output(0, tensor::DeepCopy(input)); + auto vec = context->mutable_output(0)->vec(); + RandomShuffle(vec.data(), vec.data() + size, uniform); + } else { + // For >= 2D, shuffle indices and then copy across + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + const auto input_mat = input.flat_outer_dims(); + auto output_mat = output->flat_outer_dims(); + std::vector permutation(size); + for (int i = 0; i < size; i++) { + permutation[i] = i; + } + RandomShuffle(permutation.begin(), permutation.end(), uniform); + for (int i = 0; i < size; i++) { + output_mat.template chip<0>(i) = + input_mat.template chip<0>(permutation[i]); + } + } + } + } + + private: + GuardedPhiloxRandom generator_; +}; + +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomShuffle").Device(DEVICE_CPU).TypeConstraint("T"), \ + RandomShuffleOp); +TF_CALL_ALL_TYPES(REGISTER) + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc 
b/tensorflow/core/kernels/random_shuffle_queue_op.cc new file mode 100644 index 0000000000..561ec76e53 --- /dev/null +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -0,0 +1,740 @@ +// See docs in ../ops/data_flow_ops.cc. + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class RandomShuffleQueue : public QueueBase { + public: + RandomShuffleQueue(int32 capacity, int32 min_after_dequeue, int64 seed, + int64 seed2, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name); + Status Initialize(); // Must be called before any other method. + + // Implementations of QueueInterface methods -------------------------------- + + Status ValidateTuple(const Tuple& tuple) override; + Status ValidateManyTuple(const Tuple& tuple) override; + void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) override; + void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) override; + void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override; + void TryDequeueMany(int num_elements, OpKernelContext* ctx, + CallbackWithTuple callback) override; + void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, + DoneCallback callback) override; + Status MatchesNodeDef(const NodeDef& node_def) override; + + int32 size() override { + mutex_lock lock(mu_); + return queues_[0].size(); + } + + private: + enum Action { kEnqueue, kDequeue }; + + ~RandomShuffleQueue() override {} + + TensorShape ManyOutShape(int i, int batch_size) { + TensorShape shape({batch_size}); + shape.AppendShape(component_shapes_[i]); + return shape; + } + + // Helper for dequeuing a single random element from queues_. + void DequeueLocked(OpKernelContext* ctx, Tuple* tuple) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + void Cancel(Action action, CancellationToken token); + + // Helper for cancelling all pending Enqueue(Many) operations when + // Close is called with cancel_pending_enqueues. + void CloseAndCancel(); + + // Tries to enqueue/dequeue (or close) based on whatever is at the + // front of enqueue_attempts_/dequeue_attempts_. Appends to + // *finished the callback for any finished attempt (so it may be + // called once mu_ is released). Returns true if any progress was + // made. + struct CleanUp { + CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) + : finished(f), to_deregister(ct), cm(cm) {} + DoneCallback finished; + CancellationToken to_deregister; + CancellationManager* cm; + }; + bool TryAttemptLocked(Action action, std::vector* clean_up) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Tries to make progress on the enqueues or dequeues at the front + // of the *_attempts_ queues. 
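// Descriptive note (assumption, not committed code): pending operations are
// queued as Attempt records (declared below). Each attempt's run_callback is
// retried while holding mu_ and reports kNoProgress (leave it at the front
// and stop), kProgress (some queue state changed, so FlushUnlocked retries
// both queues), or kComplete (the attempt is finished: pop it and run its
// done_callback once mu_ is released). FlushUnlocked alternates between the
// enqueue and dequeue attempt queues until neither side makes progress.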
+ void FlushUnlocked(); + + const int32 capacity_; + const int32 min_after_dequeue_; + const int64 original_seed_; + const int64 original_seed2_; + + mutex mu_; + typedef std::vector SubQueue; + std::vector queues_ GUARDED_BY(mu_); + bool closed_ GUARDED_BY(mu_); + random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); + random::SingleSampleAdapter generator_ GUARDED_BY(mu_); + + enum RunResult { kNoProgress, kProgress, kComplete }; + struct Attempt; + typedef std::function RunCallback; + struct Attempt { + int32 elements_requested; + DoneCallback done_callback; // must be run outside mu_ + OpKernelContext* context; + CancellationToken cancellation_token; + RunCallback run_callback; // must be run while holding mu_ + bool is_cancelled; + Tuple tuple; + + Attempt(int32 elements_requested, DoneCallback done_callback, + OpKernelContext* context, CancellationToken cancellation_token, + RunCallback run_callback) + : elements_requested(elements_requested), + done_callback(done_callback), + context(context), + cancellation_token(cancellation_token), + run_callback(run_callback), + is_cancelled(false) {} + }; + std::deque enqueue_attempts_ GUARDED_BY(mu_); + std::deque dequeue_attempts_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue); +}; + +RandomShuffleQueue::RandomShuffleQueue( + int capacity, int min_after_dequeue, int64 seed, int64 seed2, + const DataTypeVector& component_dtypes, + const std::vector& component_shapes, const string& name) + : QueueBase(component_dtypes, component_shapes, name), + capacity_(capacity), + min_after_dequeue_(min_after_dequeue), + original_seed_(seed), + original_seed2_(seed2), + closed_(false), + generator_(&parent_generator_) { + if (seed == 0 && seed2 == 0) { + // If both seeds are unspecified, use completely random seeds. + seed = random::New64(); + seed2 = random::New64(); + } + parent_generator_ = random::PhiloxRandom(seed, seed2); +} + +Status RandomShuffleQueue::Initialize() { + if (component_dtypes_.empty()) { + return errors::InvalidArgument("Empty component types for queue ", name_); + } + if (!component_shapes_.empty() && + component_dtypes_.size() != component_shapes_.size()) { + return errors::InvalidArgument("Different number of component types (", + component_dtypes_.size(), ") vs. shapes (", + component_shapes_.size(), ")."); + } + + mutex_lock lock(mu_); + queues_.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + queues_.push_back(SubQueue()); + queues_.back().reserve(min_after_dequeue_); + } + return Status::OK(); +} + +// TODO(mrry): If these checks become a bottleneck, find a way to +// reduce the number of times that they are called. +Status RandomShuffleQueue::ValidateTuple(const Tuple& tuple) { + TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); + if (specified_shapes()) { + for (size_t i = 0; i < tuple.size(); ++i) { + if (!tuple[i].shape().IsSameSize(component_shapes_[i])) { + return errors::InvalidArgument( + "Shape mismatch in tuple component ", i, ". Expected ", + component_shapes_[i].ShortDebugString(), ", got ", + tuple[i].shape().ShortDebugString()); + } + } + } + return Status::OK(); +} + +// TODO(mrry): If these checks become a bottleneck, find a way to +// reduce the number of times that they are called. 
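// Minimal construction sketch (assumption, not committed code; in practice
// the queue is created by RandomShuffleQueueOp further below): a queue of
// (float scalar, int32 vector) tuples with capacity 100 that keeps at least
// 10 elements buffered after each dequeue, using fixed seeds.
RandomShuffleQueue* MakeQueueSketch() {
  DataTypeVector dtypes;
  dtypes.push_back(DT_FLOAT);
  dtypes.push_back(DT_INT32);
  std::vector<TensorShape> shapes;
  shapes.push_back(TensorShape({}));
  shapes.push_back(TensorShape({3}));
  auto* q = new RandomShuffleQueue(/*capacity=*/100, /*min_after_dequeue=*/10,
                                   /*seed=*/7, /*seed2=*/11, dtypes, shapes,
                                   "queue_sketch");
  TF_CHECK_OK(q->Initialize());  // must be called before any other method
  return q;
}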
+Status RandomShuffleQueue::ValidateManyTuple(const Tuple& tuple) { + TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); + const int64 batch_size = tuple[0].dim_size(0); + if (specified_shapes()) { + for (size_t i = 0; i < tuple.size(); ++i) { + // Expected shape is [batch_size] + component_shapes_[i] + const TensorShape expected_shape = ManyOutShape(i, batch_size); + if (!tuple[i].shape().IsSameSize(expected_shape)) { + return errors::InvalidArgument( + "Shape mismatch in tuple component ", i, ". Expected ", + expected_shape.ShortDebugString(), ", got ", + tuple[i].shape().ShortDebugString()); + } + } + } else { + for (size_t i = 1; i < tuple.size(); ++i) { + if (tuple[i].dim_size(0) != batch_size) { + return errors::InvalidArgument( + "All input tensors must have the same size in the 0th ", + "dimension. Component ", i, " has ", tuple[i].dim_size(0), + ", and should have ", batch_size); + } + } + } + return Status::OK(); +} + +void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) { + DCHECK_GT(queues_[0].size(), 0); + int64 index = generator_() % queues_[0].size(); + (*tuple).reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + (*tuple).push_back(*queues_[i][index].AccessTensor(ctx)); + queues_[i][index] = queues_[i].back(); + queues_[i].pop_back(); + } +} + +void RandomShuffleQueue::Cancel(Action action, CancellationToken token) { + DoneCallback callback = nullptr; + { + mutex_lock lock(mu_); + std::deque* attempts = + action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_; + + for (Attempt& attempt : *attempts) { + if (attempt.cancellation_token == token) { + attempt.is_cancelled = true; + if (action == kEnqueue) { + attempt.context->SetStatus( + errors::Cancelled("Enqueue operation was cancelled")); + } else { + attempt.context->SetStatus( + errors::Cancelled("Dequeue operation was cancelled")); + } + std::swap(callback, attempt.done_callback); + break; + } + } + } + if (callback) { + callback(); + FlushUnlocked(); + } +} + +void RandomShuffleQueue::CloseAndCancel() { + std::vector callbacks; + { + mutex_lock lock(mu_); + closed_ = true; + for (Attempt& attempt : enqueue_attempts_) { + attempt.is_cancelled = true; + attempt.context->SetStatus( + errors::Cancelled("Enqueue operation was cancelled")); + callbacks.emplace_back(std::move(attempt.done_callback)); + } + } + for (const DoneCallback& callback : callbacks) { + callback(); + } + FlushUnlocked(); +} + +bool RandomShuffleQueue::TryAttemptLocked( + Action action, std::vector* clean_up) { + std::deque* attempts = + action == kEnqueue ? 
&enqueue_attempts_ : &dequeue_attempts_; + + bool progress = false; + bool done = false; + while (!done && !attempts->empty()) { + if (attempts->front().is_cancelled) { + if (action == kEnqueue) { + LOG(INFO) << "Skipping cancelled enqueue attempt"; + } else { + LOG(INFO) << "Skipping cancelled dequeue attempt"; + } + attempts->pop_front(); + } else { + Attempt* cur_attempt = &attempts->front(); + switch (cur_attempt->run_callback(cur_attempt)) { + case kNoProgress: + done = true; + break; + case kProgress: + done = true; + progress = true; + break; + case kComplete: + progress = true; + clean_up->emplace_back(std::move(cur_attempt->done_callback), + cur_attempt->cancellation_token, + cur_attempt->context->cancellation_manager()); + attempts->pop_front(); + break; + } + } + } + return progress; +} + +void RandomShuffleQueue::FlushUnlocked() { + std::vector clean_up; + Ref(); + { + mutex_lock lock(mu_); + bool changed; + do { + changed = TryAttemptLocked(kEnqueue, &clean_up); + changed = TryAttemptLocked(kDequeue, &clean_up) || changed; + } while (changed); + } + Unref(); + for (const auto& to_clean : clean_up) { + if (to_clean.to_deregister != CancellationManager::kInvalidToken) { + // NOTE(mrry): We can safely ignore the return value of + // DeregisterCallback because the mutex mu_ ensures that the + // cleanup action only executes once. + to_clean.cm->DeregisterCallback(to_clean.to_deregister); + } + to_clean.finished(); + } +} + +void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) { + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kEnqueue, token); }); + if (!already_cancelled) { + enqueue_attempts_.emplace_back( + 1, callback, ctx, token, + [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + attempt->context->SetStatus(errors::Aborted( + "RandomShuffleQueue '", name_, "' is closed.")); + return kComplete; + } + if (queues_[0].size() < static_cast(capacity_)) { + for (int i = 0; i < num_components(); ++i) { + queues_[i].push_back(PersistentTensor(tuple[i])); + } + return kComplete; + } else { + return kNoProgress; + } + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled")); + callback(); + } +} + +void RandomShuffleQueue::TryEnqueueMany(const Tuple& tuple, + OpKernelContext* ctx, + DoneCallback callback) { + const int64 batch_size = tuple[0].dim_size(0); + if (batch_size == 0) { + callback(); + return; + } + + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kEnqueue, token); }); + if (!already_cancelled) { + enqueue_attempts_.emplace_back( + batch_size, callback, ctx, token, + [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + attempt->context->SetStatus(errors::Aborted( + "RandomShuffleQueue '", name_, "' is closed.")); + return kComplete; + } + RunResult result = kNoProgress; + while (queues_[0].size() < static_cast(capacity_)) { + result = kProgress; + const int index = + tuple[0].dim_size(0) - attempt->elements_requested; + for (int i = 0; i < num_components(); ++i) { + TensorShape 
element_shape(tuple[i].shape()); + element_shape.RemoveDim(0); + PersistentTensor element; + Tensor* element_access = nullptr; + attempt->context->allocate_persistent( + tuple[i].dtype(), element_shape, &element, &element_access); + attempt->context->SetStatus( + CopySliceToElement(tuple[i], element_access, index)); + if (!attempt->context->status().ok()) return kComplete; + queues_[i].push_back(element); + } + --attempt->elements_requested; + if (attempt->elements_requested == 0) { + return kComplete; + } + } + return result; + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled")); + callback(); + } +} + +void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx, + CallbackWithTuple callback) { + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kDequeue, token); }); + if (!already_cancelled) { + // TODO(josh11b): This makes two copies of callback, avoid this if possible. + dequeue_attempts_.emplace_back( + 1, [callback]() { callback(Tuple()); }, ctx, token, + [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int32 s = queues_[0].size(); + if (closed_ && s == 0) { + attempt->context->SetStatus(errors::OutOfRange( + "RandomShuffleQueue '", name_, "' is closed and has ", + "insufficient elements (requested ", 1, ", current size ", s, + ")")); + return kComplete; + } + if (!closed_) s -= min_after_dequeue_; + if (s > 0) { + Tuple tuple; + DequeueLocked(attempt->context, &tuple); + attempt->done_callback = [callback, tuple]() { callback(tuple); }; + return kComplete; + } else { + return kNoProgress; + } + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled")); + callback(Tuple()); + } +} + +void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + CallbackWithTuple callback) { + if (!specified_shapes()) { + ctx->SetStatus( + errors::InvalidArgument("RandomShuffleQueue's DequeueMany requires the " + "components to have specified shapes.")); + callback(Tuple()); + return; + } + if (num_elements == 0) { + Tuple tuple; + tuple.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + // TODO(josh11b,misard): Switch to allocate_output(). Problem is + // this breaks the abstraction boundary since we don't *really* + // know if and how the Tensors in the tuple we pass to callback + // correspond to the outputs of *ctx. For example, the + // ReaderRead Op uses TryDequeue() to get a filename out of a + // queue that is used internally by the reader and is not + // associated with any output of the ReaderRead. + // mrry@ adds: + // Maybe we need to pass a std::function (or + // better signature) that calls the appropriate allocator + // function in addition to ctx? (Or support a shim Allocator + // that has an internal OpKernelContext*, and dispatches to the + // appropriate method?) + // misard@ adds: + // I don't see that a std::function would help. The problem is + // that at this point (allocation time) the system doesn't know + // what is going to happen to the element read out of the + // queue. 
As long as we keep the generality that TensorFlow Ops + // do their own dynamic allocation in arbitrary C++ code, we + // need to preserve robustness to allocating output Tensors with + // the 'wrong' attributes, and fixing up with a copy. The only + // improvement I can see here in the future would be to support + // an optimized case where the queue 'knows' what attributes to + // use, and plumbs them through here. + Tensor element; + ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0), &element); + tuple.emplace_back(element); + } + callback(tuple); + return; + } + + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kDequeue, token); }); + if (!already_cancelled) { + // TODO(josh11b): This makes two copies of callback, avoid this if possible. + dequeue_attempts_.emplace_back( + num_elements, [callback]() { callback(Tuple()); }, ctx, token, + [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int32 s = queues_[0].size(); + if (closed_ && s < attempt->elements_requested) { + attempt->context->SetStatus(errors::OutOfRange( + "RandomSuffleQueue '", name_, "' is closed and has ", + "insufficient elements (requested ", + attempt->elements_requested, ", current size ", s, ")")); + return kComplete; + } + + RunResult result = kNoProgress; + if (!closed_) s -= min_after_dequeue_; + for (; s > 0; --s) { + if (attempt->tuple.empty()) { + // Only allocate tuple when we have something to dequeue + // so we don't use exceessive memory when there are many + // blocked dequeue attempts waiting. + attempt->tuple.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + const TensorShape shape = + ManyOutShape(i, attempt->elements_requested); + Tensor element; + attempt->context->allocate_temp(component_dtypes_[i], shape, + &element); + attempt->tuple.emplace_back(element); + } + } + result = kProgress; + Tuple tuple; + DequeueLocked(attempt->context, &tuple); + const int index = + attempt->tuple[0].dim_size(0) - attempt->elements_requested; + for (int i = 0; i < num_components(); ++i) { + attempt->context->SetStatus( + CopyElementToSlice(tuple[i], &attempt->tuple[i], index)); + if (!attempt->context->status().ok()) return kComplete; + } + tuple.clear(); + --attempt->elements_requested; + if (attempt->elements_requested == 0) { + tuple = attempt->tuple; + attempt->done_callback = [callback, tuple]() { + callback(tuple); + }; + return kComplete; + } + } + return result; + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled")); + callback(Tuple()); + } +} + +void RandomShuffleQueue::Close(OpKernelContext* ctx, + bool cancel_pending_enqueues, + DoneCallback callback) { + if (cancel_pending_enqueues) { + CloseAndCancel(); + callback(); + } else { + { + mutex_lock lock(mu_); + enqueue_attempts_.emplace_back( + 0, callback, ctx, CancellationManager::kInvalidToken, + [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + attempt->context->SetStatus(errors::Aborted( + "RandomShuffleQueue '", name_, "' is already closed.")); + } else { + closed_ = true; + } + return kComplete; + }); + } + FlushUnlocked(); + } +} + +Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) { + TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "RandomShuffleQueue")); + 
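// Descriptive note (assumption, not committed code): MatchesNodeDef is how a
// shared queue is validated. When a later RandomShuffleQueue op resolves, by
// container and name, to a queue that already exists in the resource manager,
// every attribute of the requesting NodeDef (capacity, min_after_dequeue,
// seeds, component types and shapes, checked below) must match the queue that
// was created first; otherwise the op fails with InvalidArgument instead of
// silently sharing a mismatched queue.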
TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); + + int32 min_after_dequeue = -1; + TF_RETURN_IF_ERROR( + GetNodeAttr(node_def, "min_after_dequeue", &min_after_dequeue)); + if (min_after_dequeue != min_after_dequeue_) { + return errors::InvalidArgument( + "Shared queue '", name_, "' has min_after_dequeue ", + min_after_dequeue_, " but requested min_after_dequeue was ", + min_after_dequeue, "."); + } + + int64 seed = -1; + int64 seed2 = -1; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed", &seed)); + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed2", &seed2)); + if ((seed != 0 || seed2 != 0) && + (seed != original_seed_ || seed2 != original_seed2_)) { + return errors::InvalidArgument( + "Shared queue '", name_, "' has random seeds (", original_seed_, ", ", + original_seed2_, ") but requested seeds are (", seed, ", ", seed2, + ")."); + } + + TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def)); + TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def)); + + return Status::OK(); +} + +typedef std::shared_ptr QueueInterfacePtr; + +// Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one +// backed by RandomShuffleQueue) that persists across different graph +// executions, and sessions. Running this op produces a single-element +// tensor of handles to Queues in the corresponding device. +class RandomShuffleQueueOp : public OpKernel { + public: + explicit RandomShuffleQueueOp(OpKernelConstruction* context) + : OpKernel(context), queue_handle_set_(false) { + OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_STRING, TensorShape({2}), + &queue_handle_, nullptr)); + if (capacity_ < 0) { + capacity_ = RandomShuffleQueue::kUnbounded; + } + OP_REQUIRES_OK(context, + context->GetAttr("min_after_dequeue", &min_after_dequeue_)); + OP_REQUIRES(context, min_after_dequeue_ >= 0, + errors::InvalidArgument("min_after_dequeue ", + min_after_dequeue_, " must be >= 0")); + OP_REQUIRES( + context, min_after_dequeue_ < capacity_, + errors::InvalidArgument("min_after_dequeue ", min_after_dequeue_, + " must be < capacity ", capacity_)); + OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_)); + OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_)); + + OP_REQUIRES_OK(context, + context->GetAttr("component_types", &component_types_)); + OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); + } + + ~RandomShuffleQueueOp() override { + // If the queue object was not shared, delete it. 
+ if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) { + TF_CHECK_OK(cinfo_.resource_manager()->Delete( + cinfo_.container(), cinfo_.name())); + } + } + + void Compute(OpKernelContext* ctx) override { + mutex_lock l(mu_); + if (!queue_handle_set_) { + OP_REQUIRES_OK(ctx, SetQueueHandle(ctx)); + } + ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx)); + } + + private: + Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); + QueueInterface* queue; + auto creator = [this](QueueInterface** ret) { + auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_, + seed2_, component_types_, + component_shapes_, cinfo_.name()); + Status s = q->Initialize(); + if (s.ok()) { + *ret = q; + } else { + q->Unref(); + } + return s; + }; + TF_RETURN_IF_ERROR( + cinfo_.resource_manager()->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &queue, creator)); + core::ScopedUnref unref_me(queue); + // Verify that the shared queue is compatible with the requested arguments. + TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def())); + auto h = queue_handle_.AccessTensor(ctx)->flat(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + queue_handle_set_ = true; + return Status::OK(); + } + + int32 capacity_; + int32 min_after_dequeue_; + int64 seed_; + int64 seed2_; + DataTypeVector component_types_; + std::vector component_shapes_; + ContainerInfo cinfo_; + + mutex mu_; + PersistentTensor queue_handle_ GUARDED_BY(mu_); + bool queue_handle_set_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp); +}; + +REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU), + RandomShuffleQueueOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc new file mode 100644 index 0000000000..a3f4e0b0cb --- /dev/null +++ b/tensorflow/core/kernels/range_sampler.cc @@ -0,0 +1,305 @@ +#include "tensorflow/core/kernels/range_sampler.h" + +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +using gtl::ArraySlice; +using gtl::MutableArraySlice; + +RangeSampler::~RangeSampler() {} + +void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique, + gtl::MutableArraySlice batch) const { + SampleBatchGetExpectedCount( + rnd, unique, batch, gtl::MutableArraySlice(), + gtl::ArraySlice(), gtl::MutableArraySlice()); +} + +void RangeSampler::SampleBatchGetExpectedCount( + random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice batch, + gtl::MutableArraySlice batch_expected_count, + gtl::ArraySlice extras, + gtl::MutableArraySlice extras_expected_count) const { + SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count, + extras, extras_expected_count, + gtl::ArraySlice()); +} + +namespace { + +// Approximates the expected count of a value in the output of SampleBatch. +// +// If unique=false, then this is (Probability(value) * batch_size) +// +// We use batch_size and num_tries, where num_tries is the observed number of +// tries it took to get batch_size unique values. 
+// +// Assuming (falsely) that the nubmer of tries to get a batch of batch_size +// distinct values is _always_ num_tries, the probability that the value +// is in a batch is (1 - (1-p)^num_tries) +static float ExpectedCountHelper(float p, int batch_size, int num_tries) { + if (num_tries == batch_size) { + // This shortcut will always be taken if unique=false + return p * batch_size; + } + // numerically stable version of (1 - (1-p)^num_tries) + return -expm1(num_tries * log1p(-p)); +} + +} // namespace + +void RangeSampler::SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, MutableArraySlice batch, + MutableArraySlice batch_expected_count, ArraySlice extras, + MutableArraySlice extras_expected_count, + ArraySlice avoided_values) const { + const int batch_size = batch.size(); + int num_tries; + + if (unique) { + CHECK_LE(batch_size + avoided_values.size(), range_); + std::unordered_set used(batch_size); + used.insert(avoided_values.begin(), avoided_values.end()); + int num_picked = 0; + num_tries = 0; + while (num_picked < batch_size) { + num_tries++; + CHECK_LT(num_tries, kint32max); + int64 value = Sample(rnd); + if (gtl::InsertIfNotPresent(&used, value)) { + batch[num_picked++] = value; + } + } + } else { + CHECK_EQ(avoided_values.size(), 0) + << "avoided_values only supported with unique=true"; + for (int i = 0; i < batch_size; i++) { + batch[i] = Sample(rnd); + } + num_tries = batch_size; + } + // Compute the expected counts of the batch and the extra values + if (batch_expected_count.size() > 0) { + CHECK_EQ(batch_size, batch_expected_count.size()); + for (int i = 0; i < batch_size; i++) { + batch_expected_count[i] = + ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries); + } + } + CHECK_EQ(extras.size(), extras_expected_count.size()); + for (size_t i = 0; i < extras.size(); i++) { + extras_expected_count[i] = + ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries); + } +} + +AllSampler::AllSampler(int64 range) + : RangeSampler(range), inv_range_(1.0 / range) {} + +void AllSampler::SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, MutableArraySlice batch, + MutableArraySlice batch_expected_count, ArraySlice extras, + MutableArraySlice extras_expected_count, + ArraySlice avoided_values) const { + const int batch_size = batch.size(); + CHECK_EQ(range_, batch_size); + for (int i = 0; i < batch_size; i++) { + batch[i] = i; + } + if (batch_expected_count.size() > 0) { + CHECK_EQ(batch_size, batch_expected_count.size()); + for (int i = 0; i < batch_size; i++) { + batch_expected_count[i] = 1; + } + } + CHECK_EQ(0, avoided_values.size()); + CHECK_EQ(extras.size(), extras_expected_count.size()); + for (size_t i = 0; i < extras.size(); i++) { + extras_expected_count[i] = 1; + } +} + +UniformSampler::UniformSampler(int64 range) + : RangeSampler(range), inv_range_(1.0 / range) {} + +int64 UniformSampler::Sample(random::SimplePhilox* rnd) const { + return rnd->Uniform64(range_); +} + +float UniformSampler::Probability(int64 value) const { return inv_range_; } + +LogUniformSampler::LogUniformSampler(int64 range) + : RangeSampler(range), log_range_(log(range + 1)) {} + +int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const { + const int64 value = + static_cast(exp(rnd->RandDouble() * log_range_)) - 1; + CHECK_GE(value, 0); + // Mathematically, value should be <= range_, but might not be due to some + // floating point roundoff, so we mod by range_. 
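// ExpectedCountHelper above computes 1 - (1-p)^num_tries as
// -expm1(num_tries * log1p(-p)). A standalone comparison with the naive
// formula showing why the stable form matters for very small p (illustrative
// sketch only):
#include <cmath>
#include <cstdio>

// Naive form: (1 - p) rounds to 1.0 when p is far below machine epsilon.
double NaiveExpectedCount(double p, int num_tries) {
  return 1.0 - std::pow(1.0 - p, num_tries);
}

// Stable form, matching ExpectedCountHelper above.
double StableExpectedCount(double p, int num_tries) {
  return -std::expm1(num_tries * std::log1p(-p));
}

int main() {
  const double p = 1e-18;  // e.g. a very rare id in a huge vocabulary
  const int num_tries = 1000;
  std::printf("naive : %g\n", NaiveExpectedCount(p, num_tries));   // 0
  std::printf("stable: %g\n", StableExpectedCount(p, num_tries));  // ~1e-15
}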
+ return value % range_; +} + +float LogUniformSampler::Probability(int64 value) const { + // value is returned iff the call to UniformDouble(log_range_) in the + // Sample() function returns a value between log(value + 1) + // and log(value + 2). The probability of this is: + // (log(value + 2) - log(value + 1)) / log_range + // To avoid two calls to log(), we compute this as follows: + return (log((value + 2.0) / (value + 1.0))) / log_range_; +} + +ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range) + : RangeSampler(range), picker_(range) { + CHECK_LT(range, kint32max); +} + +int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const { + return picker_.Pick(rnd); +} + +float ThreadUnsafeUnigramSampler::Probability(int64 value) const { + return static_cast(picker_.get_weight(value)) / picker_.total_weight(); +} + +void ThreadUnsafeUnigramSampler::Update(ArraySlice values) { + int num_updates = std::min(static_cast(values.size()), + kint32max - picker_.total_weight()); + for (int i = 0; i < num_updates; i++) { + const int64 value = values[i]; + picker_.set_weight(value, picker_.get_weight(value) + 1); + } +} + +// Thread-safe unigram sampler +UnigramSampler::UnigramSampler(int64 range) + : RangeSampler(range), unsafe_sampler_(range) { + CHECK_LT(range, kint32max); +} + +int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const { + mutex_lock lock(mu_); // could use reader lock + return unsafe_sampler_.Sample(rnd); +} + +float UnigramSampler::Probability(int64 value) const { + mutex_lock lock(mu_); // could use reader lock + return unsafe_sampler_.Probability(value); +} + +// Overriding at a high level results in far fewer lock aquisitions. +void UnigramSampler::SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, MutableArraySlice batch, + MutableArraySlice batch_expected_count, ArraySlice extras, + MutableArraySlice extras_expected_count, + ArraySlice avoided_values) const { + mutex_lock lock(mu_); // could use reader lock + unsafe_sampler_.SampleBatchGetExpectedCountAvoid( + rnd, unique, batch, batch_expected_count, extras, extras_expected_count, + avoided_values); +} + +void UnigramSampler::Update(ArraySlice values) { + mutex_lock lock(mu_); + unsafe_sampler_.Update(values); +} + +FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range, + const string& vocab_file, + float distortion, + int32 num_reserved_ids, + int32 num_shards, int32 shard) + : RangeSampler(range), + total_weight_(0.0), + num_shards_(num_shards), + shard_(shard) { + FillReservedIds(num_reserved_ids); + // TODO(vanhoucke): make this non-crashing. + TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion)); + CHECK_EQ(range, weights_.size()); + dist_sampler_.reset(new random::DistributionSampler(weights_)); +} + +FixedUnigramSampler::FixedUnigramSampler(int64 range, + const std::vector& unigrams, + float distortion, + int32 num_reserved_ids, + int32 num_shards, int32 shard) + : RangeSampler(range), + total_weight_(0.0), + num_shards_(num_shards), + shard_(shard) { + FillReservedIds(num_reserved_ids); + LoadFromUnigrams(unigrams, distortion); + // TODO(vanhoucke): make this non-crashing. 
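// LogUniformSampler above draws value = floor(exp(U * log(range + 1))) - 1,
// which gives P(value) = log((value + 2) / (value + 1)) / log(range + 1).
// A self-contained check of that closed form, using <random> in place of
// SimplePhilox (illustrative sketch only):
#include <cmath>
#include <cstdio>
#include <random>

int main() {
  const long long range = 1000000;
  const double log_range = std::log(static_cast<double>(range) + 1.0);
  std::mt19937_64 gen(17);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);

  const long long target = 99;
  const int kSamples = 2000000;
  int hits = 0;
  for (int i = 0; i < kSamples; ++i) {
    long long value =
        static_cast<long long>(std::exp(uniform(gen) * log_range)) - 1;
    value %= range;  // guard against floating point roundoff, as above
    if (value == target) ++hits;
  }
  const double closed_form =
      std::log((target + 2.0) / (target + 1.0)) / log_range;
  std::printf("empirical %.6f  closed form %.6f\n",
              static_cast<double>(hits) / kSamples, closed_form);
}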
+ CHECK_EQ(range, weights_.size()); + dist_sampler_.reset(new random::DistributionSampler(weights_)); +} + +float FixedUnigramSampler::Probability(int64 value) const { + return weights_.at(value) / total_weight_; +} + +int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const { + return dist_sampler_->Sample(rnd); +} + +void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) { + for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) { + if (word_id % num_shards_ == shard_) weights_.push_back(0.0); + } +} + +Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file, + float distortion) { + RandomAccessFile* file; + TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file)); + io::InputBuffer in(file, 262144 /*bytes*/); + string line; + int32 word_id = weights_.size(); + while (in.ReadLine(&line).ok()) { + // The vocabulary file should be in csv like format, with the last + // field the weight associated with the word. + std::vector cols = str_util::Split(line, ','); + if (cols.size() == 0) continue; + // Skip entries that do not belong to this shard. + if (word_id % num_shards_ == shard_) { + float w = 0.0; + if (!strings::safe_strtof(cols.at(cols.size() - 1).c_str(), &w)) { + return errors::InvalidArgument("Wrong vocabulary format at line: ", + line); + } + w = pow(w, distortion); + total_weight_ += w; + weights_.push_back(w); + } + ++word_id; + } + return Status::OK(); +} + +void FixedUnigramSampler::LoadFromUnigrams(const std::vector& unigrams, + float distortion) { + int32 word_id = weights_.size(); + for (float w : unigrams) { + // Skip entries that do not belong to this shard. + if (word_id % num_shards_ == shard_) { + w = pow(w, distortion); + total_weight_ += w; + weights_.push_back(w); + } + ++word_id; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h new file mode 100644 index 0000000000..18364c2c03 --- /dev/null +++ b/tensorflow/core/kernels/range_sampler.h @@ -0,0 +1,237 @@ +#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_ +#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_ + +#include + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/distribution_sampler.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/lib/random/weighted_picker.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class Env; + +// Abstract subclass for sampling from the set of non-negative integers +// [0, range) +class RangeSampler { + public: + explicit RangeSampler(int range) : range_(range) { CHECK_GT(range_, 0); } + virtual ~RangeSampler(); + + // Sample a single value + virtual int64 Sample(random::SimplePhilox* rnd) const = 0; + + // The probability that a single call to Sample() returns the given value. + // Assumes that value is in [0, range). No range checking is done. + virtual float Probability(int64 value) const = 0; + + // Fill "batch" with samples from the distribution. + // If unique=true, then we re-pick each element until we get a + // value distinct from all previously picked values in the batch. + void SampleBatch(random::SimplePhilox* rnd, bool unique, + gtl::MutableArraySlice batch) const; + + // Fill "batch" with samples from the distribution, and report + // "expected counts". 
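// LoadFromFile/LoadFromUnigrams above keep only the ids belonging to this
// shard (word_id % num_shards == shard) and raise each weight to the
// distortion power before normalizing by total_weight_. A standalone version
// of that weight-building step over an in-memory array (no Env or
// InputBuffer; illustrative only):
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> BuildShardWeights(const std::vector<float>& unigrams,
                                     float distortion, int num_shards,
                                     int shard, float* total_weight) {
  std::vector<float> weights;
  int word_id = 0;
  for (float w : unigrams) {
    if (word_id % num_shards == shard) {  // skip ids owned by other shards
      w = std::pow(w, distortion);
      *total_weight += w;
      weights.push_back(w);
    }
    ++word_id;
  }
  return weights;
}

int main() {
  const std::vector<float> unigrams = {1, 2, 4, 8, 16, 32, 64, 128, 256};
  float total = 0.0f;
  const std::vector<float> weights =
      BuildShardWeights(unigrams, 0.8f, /*num_shards=*/1, /*shard=*/0, &total);
  std::printf("total weight: %.2f\n", total);  // ~197.05, as in the tests below
  for (float w : weights) std::printf("%.4f ", w / total);
  std::printf("\n");
}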
+ // + // The "expected count" of a value is an estimate of the expected + // number of occurrences of the value in the batch returned by a + // call to this function with the given parameters. If unique=true, + // the expected count is an inclusion probability. For details on + // this estimation, see the comment to "ExpectedCountHelper" in the + // .cc file. + // + // Expected counts for the elements of the returned "batch" are reported + // in the aligned array "batch_expected_count". + // + // The user can optionally provide "extras", containg values in the range. + // The expected counts for the extras are reported in the aligned array + // "extras_expected_count". + // + // "batch_expected_count" must have size equal to 0 or to the size of "batch". + // "extras" and "extras_expected_count" must have equal size. + void SampleBatchGetExpectedCount( + random::SimplePhilox* rnd, bool unique, + gtl::MutableArraySlice batch, + gtl::MutableArraySlice batch_expected_count, + gtl::ArraySlice extras, + gtl::MutableArraySlice extras_expected_count) const; + + // Same as SampleBatchGetExpectedCount (see above), but with avoided values. + // We repick to avoid all of the values in "avoided_values". + // "avoided_values" is only supported with unique=true. If + // unique=false, then avoided_values must be empty. + virtual void SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, + gtl::MutableArraySlice batch, + gtl::MutableArraySlice batch_expected_count, + gtl::ArraySlice extras, + gtl::MutableArraySlice extras_expected_count, + gtl::ArraySlice avoided_values) const; + + // Does this sampler need to be updated with values, e.g. UnigramSampler + virtual bool NeedsUpdates() const { return false; } + + // Updates the underlying distribution + virtual void Update(gtl::ArraySlice values) { + LOG(FATAL) << "Update not supported for this sampler type."; + } + + int64 range() { return range_; } + + protected: + const int64 range_; +}; + +// An AllSampler only samples batches of size equal to range. +// It returns the entire range. +// It cannot sample single values. 
+class AllSampler : public RangeSampler { + public: + explicit AllSampler(int64 range); + + ~AllSampler() override {} + + int64 Sample(random::SimplePhilox* rnd) const override { + LOG(FATAL) << "Should not be called"; + } + + float Probability(int64 value) const override { + LOG(FATAL) << "Should not be called"; + } + + void SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, + gtl::MutableArraySlice batch, + gtl::MutableArraySlice batch_expected_count, + gtl::ArraySlice extras, + gtl::MutableArraySlice extras_expected_count, + gtl::ArraySlice avoided_values) const override; + + private: + const float inv_range_; +}; + +class UniformSampler : public RangeSampler { + public: + explicit UniformSampler(int64 range); + + ~UniformSampler() override {} + + int64 Sample(random::SimplePhilox* rnd) const override; + + float Probability(int64 value) const override; + + private: + const float inv_range_; +}; + +class LogUniformSampler : public RangeSampler { + public: + explicit LogUniformSampler(int64 range); + + ~LogUniformSampler() override {} + + int64 Sample(random::SimplePhilox* rnd) const override; + + float Probability(int64 value) const override; + + private: + const double log_range_; +}; + +// Thread-unsafe unigram sampler +class ThreadUnsafeUnigramSampler : public RangeSampler { + public: + explicit ThreadUnsafeUnigramSampler(int64 range); + ~ThreadUnsafeUnigramSampler() override {} + + int64 Sample(random::SimplePhilox* rnd) const override; + + float Probability(int64 value) const override; + + bool NeedsUpdates() const override { return true; } + void Update(gtl::ArraySlice values) override; + + private: + random::WeightedPicker picker_; +}; + +// Thread-safe unigram sampler +class UnigramSampler : public RangeSampler { + public: + explicit UnigramSampler(int64 range); + ~UnigramSampler() override {} + + int64 Sample(random::SimplePhilox* rnd) const override; + + float Probability(int64 value) const override; + + // Overriding at a high level results in far fewer lock aquisitions. + void SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, + gtl::MutableArraySlice batch, + gtl::MutableArraySlice batch_expected_count, + gtl::ArraySlice extras, + gtl::MutableArraySlice extras_expected_count, + gtl::ArraySlice avoided_values) const override; + + bool NeedsUpdates() const override { return true; } + void Update(gtl::ArraySlice values) override; + + private: + ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_); + mutable mutex mu_; +}; + +// A unigram sampler that uses a fixed unigram distribution read from a +// file or passed in as an in-memory array instead of building up the +// distribution from data on the fly. There is also an option to skew the +// distribution by applying a distortion power to the weights. +class FixedUnigramSampler : public RangeSampler { + public: + // The vocab_file is assumed to be a CSV, with the last entry of each row a + // value representing the counts or probabilities for the corresponding ID. + FixedUnigramSampler(Env* env, int64 range, const string& vocab_file, + float distortion, int32 num_reserved_ids, + int32 num_shards, int32 shard); + + FixedUnigramSampler(int64 range, const std::vector& unigrams, + float distortion, int32 num_reserved_ids, + int32 num_shards, int32 shard); + + float Probability(int64 value) const override; + + int64 Sample(random::SimplePhilox* rnd) const override; + + private: + // Underlying distribution sampler. 
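// UnigramSampler above wraps the thread-unsafe sampler behind one mutex and
// overrides the batch-level entry point so the lock is taken once per batch
// rather than once per Sample()/Probability() call. A standalone sketch of
// that wrapping pattern with toy classes (not the TF types):
#include <cstdio>
#include <mutex>
#include <vector>

// Stands in for ThreadUnsafeUnigramSampler: fine single-threaded, unsafe shared.
class UnsafeCounter {
 public:
  void Add(int v) { total_ += v; }
  long long total() const { return total_; }

 private:
  long long total_ = 0;
};

// Thread-safe wrapper: one mutex plus a batch-level method.
class SafeCounter {
 public:
  void AddBatch(const std::vector<int>& values) {
    std::lock_guard<std::mutex> lock(mu_);  // one acquisition per batch
    for (int v : values) unsafe_.Add(v);
  }
  long long total() const {
    std::lock_guard<std::mutex> lock(mu_);
    return unsafe_.total();
  }

 private:
  mutable std::mutex mu_;
  UnsafeCounter unsafe_;
};

int main() {
  SafeCounter counter;
  counter.AddBatch({1, 2, 3});
  std::printf("%lld\n", counter.total());  // 6
}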
+ std::unique_ptr dist_sampler_; + // Weights for individual samples. The probability of a sample i is defined + // as weights_.at(i) / total_weight_. + std::vector weights_; + // The total weights of all samples. + float total_weight_; + // Sharding information of the sampler. The whole vocabulary is sharded + // into num_shards_ smaller ranges and each sampler is responsible for one + // such smaller range, identified by the shard number. + int32 num_shards_; + int32 shard_; + + // Fill the sampler with the appropriate number of reserved IDs. + void FillReservedIds(int32 num_reserved_ids); + // Load IDs to sample from a CSV file. It is assumed that the last item of + // each row contains a count or probability for the corresponding ID. + Status LoadFromFile(Env* env, const string& vocab_file, float distortion); + // Load from an in-memory array. + void LoadFromUnigrams(const std::vector& unigrams, float distortion); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_ diff --git a/tensorflow/core/kernels/range_sampler_test.cc b/tensorflow/core/kernels/range_sampler_test.cc new file mode 100644 index 0000000000..72c39009e4 --- /dev/null +++ b/tensorflow/core/kernels/range_sampler_test.cc @@ -0,0 +1,320 @@ +#include + +#include +#include "tensorflow/core/kernels/range_sampler.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace { + +using gtl::ArraySlice; +using gtl::MutableArraySlice; + +class RangeSamplerTest : public ::testing::Test { + protected: + void CheckProbabilitiesSumToOne() { + double sum = 0; + for (int i = 0; i < sampler_->range(); i++) { + sum += sampler_->Probability(i); + } + EXPECT_NEAR(sum, 1.0, 1e-4); + } + void CheckHistogram(int num_samples, float tolerance) { + const int range = sampler_->range(); + std::vector h(range); + std::vector a(num_samples); + // Using a fixed random seed to make the test deterministic. + random::PhiloxRandom philox(123, 17); + random::SimplePhilox rnd(&philox); + sampler_->SampleBatch(&rnd, false, &a); + for (int i = 0; i < num_samples; i++) { + int64 val = a[i]; + ASSERT_GE(val, 0); + ASSERT_LT(val, range); + h[val]++; + } + for (int val = 0; val < range; val++) { + EXPECT_NEAR((h[val] + 0.0) / num_samples, sampler_->Probability(val), + tolerance); + } + } + void Update1() { + // Add the value 3 ten times. + std::vector a(10); + for (int i = 0; i < 10; i++) { + a[i] = 3; + } + sampler_->Update(a); + } + void Update2() { + // Add the value n n times. 
+ int64 a[10]; + for (int i = 0; i < 10; i++) { + a[i] = i; + } + for (int64 i = 1; i < 10; i++) { + sampler_->Update(ArraySlice(a + i, 10 - i)); + } + } + std::unique_ptr sampler_; +}; + +TEST_F(RangeSamplerTest, UniformProbabilities) { + sampler_.reset(new UniformSampler(10)); + for (int i = 0; i < 10; i++) { + CHECK_EQ(sampler_->Probability(i), sampler_->Probability(0)); + } +} + +TEST_F(RangeSamplerTest, UniformChecksum) { + sampler_.reset(new UniformSampler(10)); + CheckProbabilitiesSumToOne(); +} + +TEST_F(RangeSamplerTest, UniformHistogram) { + sampler_.reset(new UniformSampler(10)); + CheckHistogram(1000, 0.05); +} + +TEST_F(RangeSamplerTest, LogUniformProbabilities) { + int range = 1000000; + sampler_.reset(new LogUniformSampler(range)); + for (int i = 100; i < range; i *= 2) { + float ratio = sampler_->Probability(i) / sampler_->Probability(i / 2); + EXPECT_NEAR(ratio, 0.5, 0.1); + } +} + +TEST_F(RangeSamplerTest, LogUniformChecksum) { + sampler_.reset(new LogUniformSampler(10)); + CheckProbabilitiesSumToOne(); +} + +TEST_F(RangeSamplerTest, LogUniformHistogram) { + sampler_.reset(new LogUniformSampler(10)); + CheckHistogram(1000, 0.05); +} + +TEST_F(RangeSamplerTest, UnigramProbabilities1) { + sampler_.reset(new UnigramSampler(10)); + Update1(); + EXPECT_NEAR(sampler_->Probability(3), 0.55, 1e-4); + for (int i = 0; i < 10; i++) { + if (i != 3) { + ASSERT_NEAR(sampler_->Probability(i), 0.05, 1e-4); + } + } +} +TEST_F(RangeSamplerTest, UnigramProbabilities2) { + sampler_.reset(new UnigramSampler(10)); + Update2(); + for (int i = 0; i < 10; i++) { + ASSERT_NEAR(sampler_->Probability(i), (i + 1) / 55.0, 1e-4); + } +} +TEST_F(RangeSamplerTest, UnigramChecksum) { + sampler_.reset(new UnigramSampler(10)); + Update1(); + CheckProbabilitiesSumToOne(); +} + +TEST_F(RangeSamplerTest, UnigramHistogram) { + sampler_.reset(new UnigramSampler(10)); + Update1(); + CheckHistogram(1000, 0.05); +} + +static const char kVocabContent[] = + "w1,1\n" + "w2,2\n" + "w3,4\n" + "w4,8\n" + "w5,16\n" + "w6,32\n" + "w7,64\n" + "w8,128\n" + "w9,256"; +TEST_F(RangeSamplerTest, FixedUnigramProbabilities) { + Env* env = Env::Default(); + string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); + TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); + sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0)); + // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 + for (int i = 0; i < 9; i++) { + ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4); + } +} +TEST_F(RangeSamplerTest, FixedUnigramChecksum) { + Env* env = Env::Default(); + string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); + TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); + sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0)); + CheckProbabilitiesSumToOne(); +} + +TEST_F(RangeSamplerTest, FixedUnigramHistogram) { + Env* env = Env::Default(); + string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); + TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); + sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0)); + CheckHistogram(1000, 0.05); +} +TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1) { + Env* env = Env::Default(); + string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); + TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); + sampler_.reset(new FixedUnigramSampler(env, 10, fname, 0.8, 1, 1, 0)); + ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4); + // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 + for (int i = 1; i < 10; i++) { + 
ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4); + } +} +TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2) { + Env* env = Env::Default(); + string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); + TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); + sampler_.reset(new FixedUnigramSampler(env, 11, fname, 0.8, 2, 1, 0)); + ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4); + ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4); + // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 + for (int i = 2; i < 11; i++) { + ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4); + } +} +TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesFromVector) { + std::vector weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; + sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0)); + // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 + for (int i = 0; i < 9; i++) { + ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4); + } +} +TEST_F(RangeSamplerTest, FixedUnigramChecksumFromVector) { + std::vector weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; + sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0)); + CheckProbabilitiesSumToOne(); +} +TEST_F(RangeSamplerTest, FixedUnigramHistogramFromVector) { + std::vector weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; + sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0)); + CheckHistogram(1000, 0.05); +} +TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1FromVector) { + std::vector weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; + sampler_.reset(new FixedUnigramSampler(10, weights, 0.8, 1, 1, 0)); + ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4); + // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 + for (int i = 1; i < 10; i++) { + ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4); + } +} +TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2FromVector) { + std::vector weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; + sampler_.reset(new FixedUnigramSampler(11, weights, 0.8, 2, 1, 0)); + ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4); + ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4); + // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 + for (int i = 2; i < 11; i++) { + ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4); + } +} + +// AllSampler cannot call Sample or Probability directly. +// We will test SampleBatchGetExpectedCount instead. +TEST_F(RangeSamplerTest, All) { + int batch_size = 10; + sampler_.reset(new AllSampler(10)); + std::vector batch(batch_size); + std::vector batch_expected(batch_size); + std::vector extras(2); + std::vector extras_expected(2); + extras[0] = 0; + extras[1] = batch_size - 1; + sampler_->SampleBatchGetExpectedCount(nullptr, // no random numbers needed + false, &batch, &batch_expected, extras, + &extras_expected); + for (int i = 0; i < batch_size; i++) { + EXPECT_EQ(i, batch[i]); + EXPECT_EQ(1, batch_expected[i]); + } + EXPECT_EQ(1, extras_expected[0]); + EXPECT_EQ(1, extras_expected[1]); +} + +TEST_F(RangeSamplerTest, Unique) { + // We sample num_batches batches, each without replacement. + // + // We check that the returned expected counts roughly agree with each other + // and with the average observed frequencies over the set of batches. 
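// The Unique and Avoid tests below exercise rejection sampling without
// replacement: keep drawing until a value appears that is neither already in
// the batch nor in avoided_values. A standalone model of that loop using
// <random> and std::unordered_set in place of the sampler's own proposal
// distribution (illustrative only):
#include <cstdio>
#include <random>
#include <unordered_set>
#include <vector>

// Draws batch_size distinct values from [0, range), rejecting duplicates and
// anything in `avoided`; returns the number of tries it took.
int SampleUnique(int range, int batch_size,
                 const std::unordered_set<int>& avoided, std::mt19937* gen,
                 std::vector<int>* batch) {
  std::uniform_int_distribution<int> dist(0, range - 1);
  std::unordered_set<int> used(avoided);
  int num_tries = 0;
  while (static_cast<int>(batch->size()) < batch_size) {
    ++num_tries;
    const int value = dist(*gen);
    if (used.insert(value).second) batch->push_back(value);
  }
  return num_tries;
}

int main() {
  std::mt19937 gen(123);
  std::vector<int> batch;
  const int tries = SampleUnique(100, 98, /*avoided=*/{17, 23}, &gen, &batch);
  long long sum = 0;
  for (int v : batch) sum += v;
  // Everything in [0, 100) except 17 and 23: 100 * 99 / 2 - 17 - 23 = 4910.
  std::printf("sum=%lld tries=%d\n", sum, tries);
}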
+ random::PhiloxRandom philox(123, 17); + random::SimplePhilox rnd(&philox); + const int range = 100; + const int batch_size = 50; + const int num_batches = 100; + sampler_.reset(new LogUniformSampler(range)); + std::vector histogram(range); + std::vector batch(batch_size); + std::vector all_values(range); + for (int i = 0; i < range; i++) { + all_values[i] = i; + } + std::vector expected(range); + + // Sample one batch and get the expected counts of all values + sampler_->SampleBatchGetExpectedCount( + &rnd, true, &batch, MutableArraySlice(), all_values, &expected); + // Check that all elements are unique + std::set s(batch.begin(), batch.end()); + CHECK_EQ(batch_size, s.size()); + + for (int trial = 0; trial < num_batches; trial++) { + std::vector trial_expected(range); + sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch, + MutableArraySlice(), + all_values, &trial_expected); + for (int i = 0; i < range; i++) { + EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5); + } + for (int i = 0; i < batch_size; i++) { + histogram[batch[i]]++; + } + } + for (int i = 0; i < range; i++) { + // Check that the computed expected count agrees with the average observed + // count. + const float average_count = static_cast(histogram[i]) / num_batches; + EXPECT_NEAR(expected[i], average_count, 0.2); + } +} + +TEST_F(RangeSamplerTest, Avoid) { + random::PhiloxRandom philox(123, 17); + random::SimplePhilox rnd(&philox); + sampler_.reset(new LogUniformSampler(100)); + std::vector avoided(2); + avoided[0] = 17; + avoided[1] = 23; + std::vector batch(98); + + // We expect to pick all elements of [0, 100) except the avoided two. + sampler_->SampleBatchGetExpectedCountAvoid( + &rnd, true, &batch, MutableArraySlice(), ArraySlice(), + MutableArraySlice(), avoided); + + int sum = 0; + for (auto val : batch) { + sum += val; + } + const int expected_sum = 100 * 99 / 2 - avoided[0] - avoided[1]; + EXPECT_EQ(expected_sum, sum); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reader_base.cc b/tensorflow/core/kernels/reader_base.cc new file mode 100644 index 0000000000..06211efb38 --- /dev/null +++ b/tensorflow/core/kernels/reader_base.cc @@ -0,0 +1,156 @@ +#include "tensorflow/core/kernels/reader_base.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +// ReaderBase ------------------------------------------------------ + +ReaderBase::ReaderBase(const string& name) : name_(name) {} + +int64 ReaderBase::NumRecordsProduced() { + mutex_lock lock(mu_); + return num_records_produced_; +} + +int64 ReaderBase::NumWorkUnitsCompleted() { + mutex_lock lock(mu_); + return work_finished_; +} + +Status ReaderBase::Reset() { + mutex_lock lock(mu_); + return ResetLocked(); +} + +Status ReaderBase::ResetLocked() { + work_started_ = 0; + work_finished_ = 0; + num_records_produced_ = 0; + work_.clear(); + return Status::OK(); +} + +Status ReaderBase::SerializeState(string* state) { + mutex_lock lock(mu_); + return SerializeStateLocked(state); +} + +Status ReaderBase::SerializeStateLocked(string* state) { + return errors::Unimplemented("Reader SerializeState"); +} + +Status ReaderBase::RestoreState(const string& state) { + mutex_lock lock(mu_); + Status status = 
RestoreStateLocked(state); + if (!status.ok()) { + ResetLocked(); + } + return status; +} + +Status ReaderBase::RestoreStateLocked(const string& state) { + return errors::Unimplemented("Reader RestoreState"); +} + +void ReaderBase::Read(QueueInterface* queue, string* key, string* value, + OpKernelContext* context) { + mutex_lock lock(mu_); + while (true) { + if (!work_in_progress()) { + GetNextWorkLocked(queue, context); + if (!context->status().ok()) return; + } + + bool produced = false; + bool at_end = false; + Status status = ReadLocked(key, value, &produced, &at_end); + + if (!at_end && status.ok() && !produced) { + status = errors::Internal( + "ReadLocked() for ", name(), + " must set *at_end=true, *produced=true, or return an error."); + } + if (!status.ok() && produced) { + status = errors::Internal("ReadLocked() for ", name(), + " set *produced=true *and* returned an error: ", + status.ToString()); + } + if (status.ok() && at_end) { + status = OnWorkFinishedLocked(); + work_finished_ = work_started_; + } + if (!status.ok()) { + context->SetStatus(status); + return; + } + if (produced) { + ++num_records_produced_; + return; + } + } +} + +void ReaderBase::GetNextWorkLocked(QueueInterface* queue, + OpKernelContext* context) { + Notification n; + queue->TryDequeue( + context, [this, context, &n](const QueueInterface::Tuple& tuple) { + if (context->status().ok()) { + if (tuple.size() != 1) { + context->SetStatus( + errors::InvalidArgument("Expected single component queue")); + } else if (tuple[0].dtype() != DT_STRING) { + context->SetStatus(errors::InvalidArgument( + "Expected queue with single string component")); + } else if (tuple[0].NumElements() != 1) { + context->SetStatus(errors::InvalidArgument( + "Expected to dequeue a one-element string tensor")); + } else { + work_ = tuple[0].flat()(0); + ++work_started_; + Status status = OnWorkStartedLocked(); + if (!status.ok()) { + context->SetStatus(status); + --work_started_; + } + } + } + n.Notify(); + }); + n.WaitForNotification(); +} + +void ReaderBase::SaveBaseState(ReaderBaseState* state) const { + state->Clear(); + state->set_work_started(work_started_); + state->set_work_finished(work_finished_); + state->set_num_records_produced(num_records_produced_); + state->set_current_work(work_); +} + +Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { + work_started_ = state.work_started(); + work_finished_ = state.work_finished(); + num_records_produced_ = state.num_records_produced(); + work_ = state.current_work(); + if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) { + return errors::InvalidArgument( + "Unexpected negative value when restoring in ", name(), ": ", + state.ShortDebugString()); + } + if (work_started_ > work_finished_) { + return errors::InvalidArgument( + "Inconsistent work started vs. 
finished when restoring in ", name(), + ": ", state.ShortDebugString()); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reader_base.h b/tensorflow/core/kernels/reader_base.h new file mode 100644 index 0000000000..d344300388 --- /dev/null +++ b/tensorflow/core/kernels/reader_base.h @@ -0,0 +1,107 @@ +#ifndef TENSORFLOW_KERNELS_READER_BASE_H_ +#define TENSORFLOW_KERNELS_READER_BASE_H_ + +#include +#include +#include +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/reader_interface.h" +#include "tensorflow/core/kernels/reader_base.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +// Default implementation of ReaderInterface. +class ReaderBase : public ReaderInterface { + public: + // name: For use in error messages, should mention both the name of + // the op and the node. + explicit ReaderBase(const string& name); + + // Note that methods with names ending in "Locked" are called while + // the ReaderBase's mutex is held. + + // Implement this function in descendants ----------------------------------- + + // Produce the next key/value pair from the current work item. + // This is called "Locked" since it is executed under a mutex + // that serializes all Reader calls. + // Usage: + // a) If a record was successfully produced, set *produced = true, + // and fill in *key and *value. + // b) If no more records will be produced for this work item, set + // *at_end = true. + // c) If a record was produced, but no more will be produced, you + // may either do both (a) and (b), or do (a) in this call and do (b) in + // the next call to ReadLocked(). + // d) If there was an error producing (e.g. an error reading the file, + // data corruption), return a non-OK() status. ReadLocked may be + // called again if the user reruns this part of the graph. + virtual Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) = 0; + + // Descendants may optionally implement these ------------------------------- + + // Called when work starts / finishes. + virtual Status OnWorkStartedLocked() { return Status::OK(); } + virtual Status OnWorkFinishedLocked() { return Status::OK(); } + + // Called to reset the Reader to a newly constructed state. + virtual Status ResetLocked(); + + // Default implementation generates an Unimplemented error. + // See the protected helper methods below. + virtual Status SerializeStateLocked(string* state); + virtual Status RestoreStateLocked(const string& state); + + // Accessors ---------------------------------------------------------------- + + // Always true during a call to ReadLocked(). + bool work_in_progress() const { return work_finished_ < work_started_; } + + // Returns the name of the current work item (valid if + // work_in_progress() returns true). May change between calls to + // ReadLocked(). + const string& current_work() const { return work_; } + + // What was passed to the constructor. + const string& name() const { return name_; } + + protected: + // For descendants wishing to implement serialize & restore state. + + // Writes ReaderBase state to *state. + void SaveBaseState(ReaderBaseState* state) const; + + // Restores ReaderBase state from state. Assumes state was filled + // using SaveBaseState() above. + Status RestoreBaseState(const ReaderBaseState& state); + + private: + // Implementations of ReaderInterface methods. These ensure thread-safety + // and call the methods above to do the work. 
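// The ReadLocked() contract above (cases a-d) is easiest to see in a toy
// reader. A standalone sketch over an in-memory list of records, with a plain
// bool standing in for tensorflow::Status (illustrative only):
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// One work item: a named list of records, standing in for a file.
struct WorkItem {
  std::string name;
  std::vector<std::string> records;
};

class ToyReader {
 public:
  explicit ToyReader(WorkItem work) : work_(std::move(work)) {}

  // Case (a): fill *key/*value and set *produced. Case (b): set *at_end.
  // Returning false stands in for case (d), a non-OK status.
  bool ReadLocked(std::string* key, std::string* value, bool* produced,
                  bool* at_end) {
    if (next_ >= work_.records.size()) {
      *at_end = true;
      return true;
    }
    *key = work_.name + ":" + std::to_string(next_);
    *value = work_.records[next_++];
    *produced = true;
    return true;
  }

 private:
  WorkItem work_;
  size_t next_ = 0;
};

int main() {
  ToyReader reader({"file0", {"a", "b"}});
  while (true) {
    std::string key, value;
    bool produced = false, at_end = false;
    if (!reader.ReadLocked(&key, &value, &produced, &at_end)) break;
    if (produced) std::printf("%s -> %s\n", key.c_str(), value.c_str());
    if (at_end) break;
  }
}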
+ void Read(QueueInterface* queue, string* key, string* value, + OpKernelContext* context) override; + Status Reset() override; + int64 NumRecordsProduced() override; + int64 NumWorkUnitsCompleted() override; + Status SerializeState(string* state) override; + Status RestoreState(const string& state) override; + + // For implementing Read(). Dequeues the next work item from + // *queue, and if successful updates work_, work_started_ + // (establishing work_in_progress() == true) and calls + // OnWorkStartedLocked(). May block. + void GetNextWorkLocked(QueueInterface* queue, OpKernelContext* context); + + mutable mutex mu_; + const string name_; + int64 work_started_ = 0; + int64 work_finished_ = 0; + int64 num_records_produced_ = 0; + string work_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_READER_BASE_H_ diff --git a/tensorflow/core/kernels/reader_base.proto b/tensorflow/core/kernels/reader_base.proto new file mode 100644 index 0000000000..4335cb2152 --- /dev/null +++ b/tensorflow/core/kernels/reader_base.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +// For serializing and restoring the state of ReaderBase, see +// reader_base.h for details. +message ReaderBaseState { + int64 work_started = 1; + int64 work_finished = 2; + int64 num_records_produced = 3; + bytes current_work = 4; +}; diff --git a/tensorflow/core/kernels/reader_ops.cc b/tensorflow/core/kernels/reader_ops.cc new file mode 100644 index 0000000000..38c1013604 --- /dev/null +++ b/tensorflow/core/kernels/reader_ops.cc @@ -0,0 +1,132 @@ +// See docs in ../ops/io_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/reader_interface.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class ReaderVerbOpKernel : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + ReaderInterface* reader; + OP_REQUIRES_OK(context, + GetResourceFromContext(context, "reader_handle", &reader)); + ComputeWithReader(context, reader); + reader->Unref(); + } + + protected: + virtual void ComputeWithReader(OpKernelContext* context, + ReaderInterface* reader) = 0; +}; + +class ReaderReadOp : public ReaderVerbOpKernel { + public: + using ReaderVerbOpKernel::ReaderVerbOpKernel; + + void ComputeWithReader(OpKernelContext* context, + ReaderInterface* reader) override { + QueueInterface* queue; + OP_REQUIRES_OK(context, + GetResourceFromContext(context, "queue_handle", &queue)); + core::ScopedUnref unref_me(queue); + Tensor* key = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("key", TensorShape({}), &key)); + Tensor* value = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("value", TensorShape({}), &value)); + + auto key_scalar = key->scalar(); + auto value_scalar = value->scalar(); + reader->Read(queue, &key_scalar(), &value_scalar(), context); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp); + +class ReaderNumRecordsProducedOp : public ReaderVerbOpKernel { + public: + using ReaderVerbOpKernel::ReaderVerbOpKernel; + + void ComputeWithReader(OpKernelContext* context, + ReaderInterface* reader) override { + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("records_produced", + TensorShape({}), &output)); + output->scalar()() = reader->NumRecordsProduced(); + } +}; + 
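// ReaderVerbOpKernel above is a template-method pattern: the base Compute()
// resolves the shared reader resource from its handle and then defers to a
// per-verb override. A standalone sketch of that shape with a std::map as a
// stand-in resource registry (toy names, not the TF resource manager):
#include <cstdio>
#include <map>
#include <string>

struct ToyReaderState {
  long long records_produced = 0;
};

class VerbOp {
 public:
  virtual ~VerbOp() = default;

  // Shared plumbing: look up the resource, then run the verb on it.
  void Run(std::map<std::string, ToyReaderState>* registry,
           const std::string& handle) {
    auto it = registry->find(handle);
    if (it == registry->end()) {
      std::printf("no reader registered under '%s'\n", handle.c_str());
      return;
    }
    DoVerb(&it->second);
  }

 protected:
  virtual void DoVerb(ToyReaderState* reader) = 0;
};

class NumRecordsProducedOp : public VerbOp {
 protected:
  void DoVerb(ToyReaderState* reader) override {
    std::printf("records_produced = %lld\n", reader->records_produced);
  }
};

int main() {
  std::map<std::string, ToyReaderState> registry;
  registry["reader0"].records_produced = 42;
  NumRecordsProducedOp op;
  op.Run(&registry, "reader0");
}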
+REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU), + ReaderNumRecordsProducedOp); + +class ReaderNumWorkUnitsCompletedOp : public ReaderVerbOpKernel { + public: + using ReaderVerbOpKernel::ReaderVerbOpKernel; + + void ComputeWithReader(OpKernelContext* context, + ReaderInterface* reader) override { + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("units_completed", + TensorShape({}), &output)); + output->scalar()() = reader->NumWorkUnitsCompleted(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU), + ReaderNumWorkUnitsCompletedOp); + +class ReaderSerializeStateOp : public ReaderVerbOpKernel { + public: + using ReaderVerbOpKernel::ReaderVerbOpKernel; + + void ComputeWithReader(OpKernelContext* context, + ReaderInterface* reader) override { + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("state", TensorShape({}), &output)); + OP_REQUIRES_OK(context, + reader->SerializeState(&output->scalar()())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU), + ReaderSerializeStateOp); + +class ReaderRestoreStateOp : public ReaderVerbOpKernel { + public: + using ReaderVerbOpKernel::ReaderVerbOpKernel; + + void ComputeWithReader(OpKernelContext* context, + ReaderInterface* reader) override { + const Tensor* tensor; + OP_REQUIRES_OK(context, context->input("state", &tensor)); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(tensor->shape()), + errors::InvalidArgument("Reader state must be scalar, but had shape: ", + tensor->shape().DebugString())); + OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar()())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU), + ReaderRestoreStateOp); + +class ReaderResetOp : public ReaderVerbOpKernel { + public: + using ReaderVerbOpKernel::ReaderVerbOpKernel; + + void ComputeWithReader(OpKernelContext* context, + ReaderInterface* reader) override { + OP_REQUIRES_OK(context, reader->Reset()); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ReaderReset").Device(DEVICE_CPU), ReaderResetOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h new file mode 100644 index 0000000000..b412617a65 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops.h @@ -0,0 +1,66 @@ +#ifndef TENSORFLOW_KERNELS_REDUCTION_OPS_H_ +#define TENSORFLOW_KERNELS_REDUCTION_OPS_H_ + +// Functor definitions for Reduction ops, must be compilable by nvcc. + +#include +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// When eigen3 has better implementation of AllReducer and AnyReducer, +// replaces reducers here. + +// Reduction using logical_and. +struct AllReducer { + // TODO(zhifengc): Implement PacketAccess when performance matters. + static const bool PacketAccess = false; + static const bool IsStateful = false; + + EIGEN_DEVICE_FUNC void reduce(const bool t, bool* accum) const { + *accum &= t; + } + + EIGEN_DEVICE_FUNC bool initialize() const { return true; } + + EIGEN_DEVICE_FUNC bool finalize(const bool accum) const { return accum; } +}; + +// Reduction using logical_or. +struct AnyReducer { + // TODO(zhifengc): Implement PacketAccess when performance matters. 
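// The All/Any reducers here follow Eigen's reducer protocol: initialize()
// yields the identity element, reduce() folds one value into the accumulator,
// finalize() post-processes. A standalone run of that protocol over a plain
// std::vector<bool>, without Eigen (illustrative only):
#include <cstdio>
#include <vector>

struct ToyAllReducer {
  bool initialize() const { return true; }               // identity for AND
  void reduce(bool t, bool* accum) const { *accum &= t; }
  bool finalize(bool accum) const { return accum; }
};

struct ToyAnyReducer {
  bool initialize() const { return false; }              // identity for OR
  void reduce(bool t, bool* accum) const { *accum |= t; }
  bool finalize(bool accum) const { return accum; }
};

template <typename Reducer>
bool Reduce(const std::vector<bool>& values, const Reducer& reducer) {
  bool accum = reducer.initialize();
  for (bool v : values) reducer.reduce(v, &accum);
  return reducer.finalize(accum);
}

int main() {
  const std::vector<bool> v = {true, false, true};
  std::printf("all=%d any=%d\n", Reduce(v, ToyAllReducer()),
              Reduce(v, ToyAnyReducer()));
}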
+ static const bool PacketAccess = false; + static const bool IsStateful = false; + + EIGEN_DEVICE_FUNC void reduce(const bool t, bool* accum) const { + *accum |= t; + } + + EIGEN_DEVICE_FUNC bool initialize() const { return false; } + + EIGEN_DEVICE_FUNC bool finalize(const bool accum) const { return accum; } +}; + +template +void ReduceEigenImpl(const Device& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Reducer& reducer) { + out.device(d) = in.reduce(reduction_axes, reducer); +} + +template +struct ReduceFunctor { + template + static void Reduce(const Device& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Reducer& reducer); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_REDUCTION_OPS_H_ diff --git a/tensorflow/core/kernels/reduction_ops_all.cc b/tensorflow/core/kernels/reduction_ops_all.cc new file mode 100644 index 0000000000..11d399e70a --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_all.cc @@ -0,0 +1,17 @@ +#include "tensorflow/core/kernels/reduction_ops_common.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("All") + .Device(DEVICE_CPU) + .HostMemory("reduction_indices"), + ReductionOp); + +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("All") + .Device(DEVICE_GPU) + .HostMemory("reduction_indices"), + ReductionOp); +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_any.cc b/tensorflow/core/kernels/reduction_ops_any.cc new file mode 100644 index 0000000000..a89ef22b08 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_any.cc @@ -0,0 +1,17 @@ +#include "tensorflow/core/kernels/reduction_ops_common.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("Any") + .Device(DEVICE_CPU) + .HostMemory("reduction_indices"), + ReductionOp); + +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("Any") + .Device(DEVICE_GPU) + .HostMemory("reduction_indices"), + ReductionOp); +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h new file mode 100644 index 0000000000..2bde3a1a54 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -0,0 +1,302 @@ +// This is an internal header file intended to only be included as the +// front-matter in the implementation files of various reduction ops. It +// is a header file because we split the various reduction ops into their +// own compilation units to get more parallelism in compilation. + +#ifndef TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_ +#define TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/reduction_ops.h" + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct Constants { + // Derive Index type. int (32-bit) or long (64-bit) depending on the + // compile-time configuration. "float" here is not relevant. + // TODO(zhifengc): Moves the definition to TTypes. 
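// Constants here pre-builds the axis sets {0}, {1}, and {0, 2} that the
// dispatch in ReductionOp::Compute below hands to Eigen's tensor reduction.
// A standalone Eigen example reducing a rank-3 tensor over axes {0, 2}
// (assumes only the unsupported Eigen Tensor module this header already
// includes; illustrative only):
#include <cstdio>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 3> in(2, 3, 4);
  in.setConstant(1.0f);

  // The kZeroTwo case: reduce over axes 0 and 2, keeping axis 1.
  Eigen::array<int, 2> zero_two;
  zero_two[0] = 0;
  zero_two[1] = 2;
  Eigen::Tensor<float, 1> out = in.sum(zero_two);

  // Each of the 3 outputs sums 2 * 4 = 8 ones.
  for (int i = 0; i < out.dimension(0); ++i) std::printf("%.1f ", out(i));
  std::printf("\n");
}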
+ typedef TTypes::Tensor::Index Index; + Eigen::array kZero; + Eigen::array kOne; + Eigen::array kZeroTwo; + + Constants() { + kZero[0] = 0; + kOne[0] = 1; + kZeroTwo[0] = 0; + kZeroTwo[1] = 2; + } +}; + +#if defined(EIGEN_HAS_INDEX_LIST) +template <> +struct Constants { + const Eigen::IndexList> kZero; + const Eigen::IndexList> kOne; + const Eigen::IndexList, Eigen::type2index<2>> kZeroTwo; +}; +#endif + +namespace { + +class ReductionHelper { + public: + ReductionHelper() : reduce_first_axis_(false) {} + + Status Simplify(const Tensor& data, const Tensor& axis, + const bool keep_dims) { + // bitmap[i] indicates whether to reduce data along i-th axis. + std::vector bitmap(data.dims(), false); + auto axis_vec = axis.flat(); + for (int64 i = 0; i < axis.NumElements(); ++i) { + const int32 index = axis_vec(i); + if (index < 0 || index >= data.dims()) { + return errors::OutOfRange("Invalid reduction dimension (", index, + " for input with ", data.dims(), + " dimension(s)"); + } + bitmap[index] = true; + } + + // Output tensor's dim sizes. + out_shape_.clear(); + for (int i = 0; i < data.dims(); ++i) { + if (!bitmap[i]) { + // If we are not reducing along dimension i. + out_shape_.push_back(data.dim_size(i)); + } else if (keep_dims) { + // We are reducing along dimension i, but we want to keep the + // same number of dimensions, so we set the dimension of i to + // '1'. + out_shape_.push_back(1); + } + } + + // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of + // the input data before doing the reduction on the resulting + // tensor. The shape of the reduction is a reshape of the final + // output. + + // We'll skip the leading 1s. + int dim_index = 0; + for (; dim_index < data.dims(); ++dim_index) { + if (data.dim_size(dim_index) != 1) break; + } + if (dim_index >= data.dims()) { + // Special case. The input is essentially a scalar. + reduce_first_axis_ = true; + } else { + // Starting from the (dim_index)-th dimension, dimensions + // alternates between runs that need to be reduced and runs that + // don't. + // + // NOTE: If a dimension has size 1, we group it as the current + // run so that we can minimize the number of runs. + // + // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1, + // 5] by axes = [1, 4], we should treat the tensor as a [6, 5] + // and reduce by axes = [1] (i.e., the output is shape [6]). + reduce_first_axis_ = bitmap[dim_index]; + data_reshape_.push_back(data.dim_size(dim_index)); + ++dim_index; + for (; dim_index < data.dims(); ++dim_index) { + const auto size = data.dim_size(dim_index); + if (size == 1) { + bitmap[dim_index] = bitmap[dim_index - 1]; + } + if (bitmap[dim_index - 1] != bitmap[dim_index]) { + // Starts a new run of reduce or !reduce. + data_reshape_.push_back(size); + } else { + // Continue a run of reduce or !reduce. + data_reshape_.back() *= size; + } + } + // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc + // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_, + // otherwise, data_reshape_[0, 2, 4, ...] is. + for (size_t i = reduce_first_axis_ ? 
1 : 0; i < data_reshape_.size(); + i += 2) { + out_reshape_.push_back(data_reshape_[i]); + } + } + + VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ","); + VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ","); + VLOG(1) << "out shape: " << str_util::Join(out_shape_, ","); + return Status::OK(); + } + + // We need to do roughly: + // tmp_out = allocate(out_reshape()) + // tmp_out.reshape(out_reshape) = data.reshape(data_reshape).reduce(axes) + // out = tmp_out.reshape(out_shape) + + // The reduction result must be allocated with this shape. + TensorShape out_reshape() const { + TensorShape shape; + for (auto size : out_reshape_) shape.AddDim(size); + return shape; + } + + // The final output shape must be allocated with this shape. + TensorShape out_shape() const { + TensorShape shape; + for (auto size : out_shape_) shape.AddDim(size); + return shape; + } + + // The reduction is on a reshaped tensor of this rank. + int ndims() const { return data_reshape_.size(); } + + // True if need to reduce the 0-th dimension. + bool reduce_first_axis() const { return reduce_first_axis_; } + + // The output is reshaped. + template + typename TTypes::Tensor out(Tensor* out) { + return out->shaped(out_reshape_); + } + + // The input is reshaped. + template + typename TTypes::ConstTensor in(const Tensor& data) { + return data.shaped(data_reshape_); + } + + private: + bool reduce_first_axis_; // True if need to reduce the 0-th dimension. + std::vector data_reshape_; // Reshape the data before reduction. + std::vector out_shape_; // The final output shape. + std::vector out_reshape_; // Reshape the output for reduction. +}; + +} // end namespace + +// For operations where the output is a reduction function along some +// dimensions of the input. +template +class ReductionOp : public OpKernel { + public: + explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + const DataType dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); + + OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& data = ctx->input(0); + const Tensor& axes = ctx->input(1); + VLOG(1) << "data shape: " << data.shape().ShortDebugString(); + VLOG(1) << "axes : " << axes.SummarizeValue(10); + + ReductionHelper helper; + OP_REQUIRES_OK(ctx, helper.Simplify(data, axes, keep_dims_)); + CHECK_GE(helper.ndims(), 0); + + // The real output shape will be assigned below. + TensorShape empty_shape; + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, empty_shape, &out)); + + if (helper.ndims() == 0 || + (helper.ndims() == 1 && !helper.reduce_first_axis())) { + // Special case. Reduces nothing. It is unclear why this is + // necessary, but tests fail without it. Look into why this + // case occurs. + if (!out->CopyFrom(data, helper.out_shape())) { + ctx->SetStatus(errors::Internal("Error during reduction copy.")); + } + return; + } + + // A temporary tensor whose size matches the size of the reduced + // output. + Tensor tmp_out; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(out->dtype(), helper.out_reshape(), &tmp_out)); + + typedef functor::ReduceFunctor Functor; + Constants constants; + const Device& d = ctx->eigen_device(); + Reducer reducer; + + if ((helper.ndims() == 1) && helper.reduce_first_axis()) { + // Reduce to a scalar. 
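// ReductionHelper::Simplify above collapses adjacent dimensions that are all
// reduced or all kept (size-1 dimensions join the current run), so that only
// a handful of reduced-rank cases need to be dispatched below. A standalone
// version of just that collapsing step, run on the [2, 1, 3, 1, 5] example
// from the comment above (illustrative only):
#include <cstdio>
#include <vector>

void CollapseAxes(const std::vector<long long>& dims,
                  std::vector<bool> bitmap,  // bitmap[i]: reduce axis i
                  std::vector<long long>* reshape, bool* reduce_first) {
  size_t d = 0;
  while (d < dims.size() && dims[d] == 1) ++d;  // skip leading size-1 dims
  if (d == dims.size()) {                       // effectively a scalar
    *reduce_first = true;
    return;
  }
  *reduce_first = bitmap[d];
  reshape->push_back(dims[d]);
  for (++d; d < dims.size(); ++d) {
    if (dims[d] == 1) bitmap[d] = bitmap[d - 1];  // size-1 dims join the run
    if (bitmap[d] != bitmap[d - 1]) {
      reshape->push_back(dims[d]);  // a new run of reduce / !reduce starts
    } else {
      reshape->back() *= dims[d];   // continue the current run
    }
  }
}

int main() {
  std::vector<long long> reshape;
  bool reduce_first = false;
  // Shape [2, 1, 3, 1, 5], reduce axes {1, 4}: collapses to [6, 5], axis 1.
  CollapseAxes({2, 1, 3, 1, 5}, {false, true, false, true, true}, &reshape,
               &reduce_first);
  std::printf("reduce_first_axis=%d reshape=[", reduce_first);
  for (long long s : reshape) std::printf(" %lld", s);
  std::printf(" ]\n");  // prints: reduce_first_axis=0 reshape=[ 6 5 ]
}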
+      Functor::Reduce(d, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data),
+                      constants.kZero, reducer);
+    } else if ((helper.ndims() == 2) && helper.reduce_first_axis()) {
+      // Can be viewed as a reduction of a matrix along 1st dimension.
+      Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
+                      constants.kZero, reducer);
+    } else if ((helper.ndims() == 2) && !helper.reduce_first_axis()) {
+      // Can be viewed as a reduction of a matrix along 2nd dimension.
+      Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
+                      constants.kOne, reducer);
+    } else if ((helper.ndims() == 3) && helper.reduce_first_axis()) {
+      // Can be viewed as a reduction of a 3D tensor along 1st and 3rd
+      // dimensions.
+      Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 3>(data),
+                      constants.kZeroTwo, reducer);
+    } else if ((helper.ndims() == 3) && !helper.reduce_first_axis()) {
+      // Can be viewed as a reduction of a 3D tensor along 2nd dimension.
+      Functor::Reduce(d, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data),
+                      constants.kOne, reducer);
+    } else {
+      // TODO(zhifengc): We can implement reduction for arbitrary rank
+      // tensor and arbitrary reduction axes by iterating the reduction
+      // multiple times. This may also be accomplished in the graph
+      // construction.
+      ctx->SetStatus(
+          errors::Unimplemented("Reducing ", data.shape().ShortDebugString(),
+                                " axes [", axes.SummarizeValue(10), "] to ",
+                                tmp_out.shape().ShortDebugString()));
+      return;
+    }
+
+    // Set the real output using the contents of the reduction but the
+    // real expected output shape. The number of elements should
+    // match between the two shapes.
+    if (!out->CopyFrom(tmp_out, helper.out_shape())) {
+      ctx->SetStatus(errors::Internal("Error during reduction copy."));
+    }
+  }
+
+ private:
+  // True if the number of dimensions should be maintained.
+  bool keep_dims_;
+};
+
+namespace functor {
+
+template <>
+struct ReduceFunctor<CPUDevice> {
+  template <typename OUT_T, typename IN_T, typename ReductionAxes,
+            typename Reducer>
+  static void Reduce(const CPUDevice& d, OUT_T out, IN_T in,
+                     const ReductionAxes& reduction_axes,
+                     const Reducer& reducer) {
+    ReduceEigenImpl(d, out, in, reduction_axes, reducer);
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
new file mode 100644
index 0000000000..8e29d2d06c
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
@@ -0,0 +1,65 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/kernels/reduction_ops.h"
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Derive Index type. int (32-bit) or long (64-bit) depending on the
+// compile-time configuration. "float" here is not relevant.
+// TODO(zhifengc): Moves the definition to TTypes.
+typedef TTypes::Tensor::Index Index; + +template <> +struct ReduceFunctor { + template + static void Reduce(const GPUDevice& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Reducer& reducer) { + ReduceEigenImpl(d, To32Bit(out), To32Bit(in), reduction_axes, reducer); + } +}; + +// T: the data type +// REDUCER: the reducer functor +// NUM_AXES: the number of axes to reduce +// IN_DIMS: the number of dimensions of the input tensor +#define DEFINE(T, REDUCER, IN_DIMS, NUM_AXES) \ + template void ReduceFunctor::Reduce( \ + const GPUDevice& d, TTypes::Tensor out, \ + TTypes::ConstTensor in, \ + const Eigen::array& reduction_axes, \ + const REDUCER& reducer); + +#define DEFINE_FOR_TYPE_AND_R(T, R) \ + DEFINE(T, R, 1, 1); \ + DEFINE(T, R, 2, 1); \ + DEFINE(T, R, 3, 1); \ + DEFINE(T, R, 3, 2); + +#define DEFINE_FOR_ALL_REDUCERS(T) \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer) + +DEFINE_FOR_ALL_REDUCERS(float); +#undef DEFINE_FOR_ALL_REDUCERS + +DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::SumReducer); +DEFINE_FOR_TYPE_AND_R(bool, AllReducer); +DEFINE_FOR_TYPE_AND_R(bool, AnyReducer); +#undef DEFINE_FOR_TYPE_AND_R + +#undef DEFINE + +} // end namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc new file mode 100644 index 0000000000..1749360b6e --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_max.cc @@ -0,0 +1,26 @@ +#include "tensorflow/core/kernels/reduction_ops_common.h" + +namespace tensorflow { + +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Max").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReductionOp>); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); +#undef REGISTER_CPU_KERNELS + +#if GOOGLE_CUDA + +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Max") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); +REGISTER_GPU_KERNELS(float); +#undef REGISTER_GPU_KERNELS + +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc new file mode 100644 index 0000000000..b00c36fed8 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_mean.cc @@ -0,0 +1,12 @@ +#include "tensorflow/core/kernels/reduction_ops_common.h" + +namespace tensorflow { + +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Mean").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReductionOp>); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); +#undef REGISTER_CPU_KERNELS + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc new file mode 100644 index 0000000000..de1f4b8520 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_min.cc @@ -0,0 +1,26 @@ +#include "tensorflow/core/kernels/reduction_ops_common.h" + +namespace tensorflow { + +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Min").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReductionOp>); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); +#undef REGISTER_CPU_KERNELS + +#if GOOGLE_CUDA + +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Min") \ + .Device(DEVICE_GPU) \ + 
.TypeConstraint("T") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); +REGISTER_GPU_KERNELS(float); +#undef REGISTER_GPU_KERNELS + +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_prod.cc b/tensorflow/core/kernels/reduction_ops_prod.cc new file mode 100644 index 0000000000..4068c7feda --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_prod.cc @@ -0,0 +1,26 @@ +#include "tensorflow/core/kernels/reduction_ops_common.h" + +namespace tensorflow { + +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Prod").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReductionOp>); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); +#undef REGISTER_CPU_KERNELS + +#if GOOGLE_CUDA + +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Prod") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); +REGISTER_GPU_KERNELS(float); +#undef REGISTER_GPU_KERNELS + +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc new file mode 100644 index 0000000000..82d685e225 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_sum.cc @@ -0,0 +1,37 @@ +#include "tensorflow/core/kernels/reduction_ops_common.h" + +namespace tensorflow { + +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Sum").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReductionOp>); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); +#undef REGISTER_CPU_KERNELS + +// NOTE: We should have mean(complex64,int32), too. But that needs to +// change Eigen::internal::MeanReducer to cast int to complex. +// We don't see immediate need of mean(complex64,int32) anyway. +REGISTER_KERNEL_BUILDER( + Name("Sum").Device(DEVICE_CPU).TypeConstraint("T"), + ReductionOp>); + +#if GOOGLE_CUDA + +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Sum") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); +REGISTER_GPU_KERNELS(float); +#undef REGISTER_GPU_KERNELS + +REGISTER_KERNEL_BUILDER( + Name("Sum").Device(DEVICE_GPU).TypeConstraint("T"), + ReductionOp>); + +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_test.cc b/tensorflow/core/kernels/reduction_ops_test.cc new file mode 100644 index 0000000000..d96da3c7f1 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_test.cc @@ -0,0 +1,73 @@ +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { + +// Creates a Graph which "reduce"s a 3D float tensor of "num" elements +// into a scalar. +static Graph* ToScalar(const string& reduce, int num) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor data(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)})); + data.flat().setRandom(); + Tensor axes(DT_INT32, TensorShape({3})); + axes.flat()(0) = 0; + axes.flat()(1) = 1; + axes.flat()(2) = 2; + test::graph::Reduce(g, reduce, test::graph::Constant(g, data), + test::graph::Constant(g, axes)); + return g; +} + +// Creates a bench which reduces a 3D tensor with total "num" floats +// into a scalar on a "device". Runs the bench for "iters" times. 
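+// Throughput is reported as items (iters * num elements reduced) and as
+// bytes (iters * num * sizeof(float)), so the Sum/Max/Prod/Mean benchmarks
+// below can be compared directly.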
+static void ReduceToScalar(int iters, const string& device, + const string& reduce, int num) { + testing::ItemsProcessed(static_cast(iters) * num); + testing::BytesProcessed(static_cast(iters) * num * sizeof(float)); + test::Benchmark(device, ToScalar(reduce, num)).Run(iters); +} + +static void BM_Sum3DToScalarCPU(int iters, int num) { + ReduceToScalar(iters, "cpu", "Sum", num); +} +BENCHMARK(BM_Sum3DToScalarCPU)->Range(1 << 13, 1 << 20); + +static void BM_Max3DToScalarCPU(int iters, int num) { + ReduceToScalar(iters, "cpu", "Max", num); +} +BENCHMARK(BM_Max3DToScalarCPU)->Range(1 << 13, 1 << 20); + +static void BM_Prod3DToScalarCPU(int iters, int num) { + ReduceToScalar(iters, "cpu", "Prod", num); +} +BENCHMARK(BM_Prod3DToScalarCPU)->Range(1 << 13, 1 << 20); + +static void BM_Mean3DToScalarCPU(int iters, int num) { + ReduceToScalar(iters, "cpu", "Mean", num); +} +BENCHMARK(BM_Mean3DToScalarCPU)->Range(1 << 13, 1 << 20); + +static void BM_Sum3DToScalarGPU(int iters, int num) { + ReduceToScalar(iters, "gpu", "Sum", num); +} +BENCHMARK(BM_Sum3DToScalarGPU)->Range(1 << 13, 1 << 20); + +static void BM_Max3DToScalarGPU(int iters, int num) { + ReduceToScalar(iters, "gpu", "Max", num); +} +BENCHMARK(BM_Max3DToScalarGPU)->Range(1 << 13, 1 << 20); + +static void BM_Prod3DToScalarGPU(int iters, int num) { + ReduceToScalar(iters, "gpu", "Prod", num); +} +BENCHMARK(BM_Prod3DToScalarGPU)->Range(1 << 13, 1 << 20); + +// Once Mean is available on GPU, enable this. +// static void BM_Mean3DToScalarGPU(int iters, int num) { +// ReduceToScalar(iters, "gpu", "Mean", num); +// } +// BENCHMARK(BM_Mean3DToScalarGPU)->Range(1 << 13, 1 << 20); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/reference_gemm.h b/tensorflow/core/kernels/reference_gemm.h new file mode 100644 index 0000000000..77c6ef35e9 --- /dev/null +++ b/tensorflow/core/kernels/reference_gemm.h @@ -0,0 +1,75 @@ +#ifndef TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ +#define TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ + +// This is an unoptimized but debuggable implementation of the GEMM matrix +// multiply function, used to compare to faster but more opaque versions, or +// for bit depths or argument combinations that aren't supported by optimized +// code. +// It assumes the row-major convention used by TensorFlow, and implements +// C = A * B, like the standard BLAS GEMM interface. If the tranpose flags are +// true, then the relevant matrix is treated as stored in column-major order. + +namespace tensorflow { +template +void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c, + size_t m, size_t n, size_t k, const T1* a, T1 offset_a, + size_t lda, const T2* b, T2 offset_b, size_t ldb, T3* c, + int32 shift_c, int32 offset_c, int32 mult_c, size_t ldc) { + int a_i_stride; + int a_l_stride; + if (transpose_a) { + a_i_stride = 1; + a_l_stride = lda; + } else { + a_i_stride = lda; + a_l_stride = 1; + } + int b_j_stride; + int b_l_stride; + if (transpose_b) { + b_j_stride = ldb; + b_l_stride = 1; + } else { + b_j_stride = 1; + b_l_stride = ldb; + } + int c_i_stride; + int c_j_stride; + if (transpose_c) { + c_i_stride = 1; + c_j_stride = ldc; + } else { + c_i_stride = ldc; + c_j_stride = 1; + } + + const int32 highest = static_cast(Eigen::NumTraits::highest()); + const int32 lowest = static_cast(Eigen::NumTraits::lowest()); + const int32 rounding = (shift_c < 1) ? 
0 : (1 << (shift_c - 1)); + + int i, j, l; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + int32 total = 0; + for (l = 0; l < k; l++) { + const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); + const int32 a_value = a[a_index] - offset_a; + const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); + const int32 b_value = b[b_index] - offset_b; + total += (a_value * b_value); + } + const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); + int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c); + if (output > highest) { + output = highest; + } + if (output < lowest) { + output = lowest; + } + c[c_index] = static_cast(output); + } + } +} +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc new file mode 100644 index 0000000000..d5dd7a8119 --- /dev/null +++ b/tensorflow/core/kernels/relu_op.cc @@ -0,0 +1,154 @@ +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/relu_op.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class ReluOp : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Relu functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +template +class Relu6Op : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Relu6 functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +template +class ReluGradOp : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): inputs that were passed to ReluOp() + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OP_REQUIRES(context, a.IsSameSize(g), + errors::InvalidArgument("g and a must be the same size")); + functor::ReluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); + } +}; + +template +class Relu6GradOp : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): inputs that were passed to Relu6Op() + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OP_REQUIRES(context, a.IsSameSize(g), + errors::InvalidArgument("g and a must be the same size")); + functor::Relu6Grad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_CPU).TypeConstraint("T"), \ + Relu6Op); \ + 
REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint("T"), \ + Relu6GradOp) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Relu::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Relu; \ + \ + template <> \ + void ReluGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + \ + extern template struct ReluGrad; \ + template <> \ + void Relu6::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Relu6; \ + \ + template <> \ + void Relu6Grad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + extern template struct Relu6Grad; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_GPU).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_GPU).TypeConstraint("T"), \ + Relu6Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint("T"), \ + Relu6GradOp) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h new file mode 100644 index 0000000000..8ed071cc4a --- /dev/null +++ b/tensorflow/core/kernels/relu_op.h @@ -0,0 +1,79 @@ +#ifndef TENSORFLOW_KERNELS_RELU_OP_H_ +#define TENSORFLOW_KERNELS_RELU_OP_H_ +// Functor definition for ReluOp and ReluGradOp, must be compilable by nvcc. + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by ReluOp to do the computations. +template +struct Relu { + // Computes Relu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + activations.device(d) = features.cwiseMax(static_cast(0)); + } +}; + +// Functor used by ReluGradOp to do the computations. +template +struct ReluGrad { + // Computes ReluGrad backprops. + // + // gradients: gradients backpropagated to the Relu op. + // features: inputs that where passed to the Relu op. + // backprops: gradients to backpropagate to the Relu inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + // NOTE: When the activation is exactly zero, we arbitrarily choose to not + // propagate the associated gradient value. 
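+    // In other words, backprop[i] = gradient[i] * 1{feature[i] > 0}.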
+ backprops.device(d) = + gradients * (features > features.constant(static_cast(0))); + } +}; + +// Functor used by Relu6Op to do the computations. +template +struct Relu6 { + // Computes Relu6 activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + activations.device(d) = + features.cwiseMax(static_cast(0)).cwiseMin(static_cast(6)); + } +}; + +// Functor used by ReluGradOp to do the computations. +template +struct Relu6Grad { + // Computes Relu6Grad backprops. + // + // gradients: gradients backpropagated to the Relu6 op. + // features: inputs that where passed to the Relu6 op. + // backprops: gradients to backpropagate to the Relu6 inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + // NOTE: When the activation is exactly zero or six, we + // arbitrarily choose to not propagate the associated gradient + // value. + backprops.device(d) = gradients * + (features > features.constant(static_cast(0))) * + (features < features.constant(static_cast(6))); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_RELU_OP_H_ diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc new file mode 100644 index 0000000000..6bd87ff8e4 --- /dev/null +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -0,0 +1,27 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include + +#include "tensorflow/core/kernels/relu_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Definition of the GPU implementations declared in relu_op.cc. +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Relu; \ + template struct functor::ReluGrad; \ + template struct functor::Relu6; \ + template struct functor::Relu6Grad; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/reshape_op.cc b/tensorflow/core/kernels/reshape_op.cc new file mode 100644 index 0000000000..7e1cf029de --- /dev/null +++ b/tensorflow/core/kernels/reshape_op.cc @@ -0,0 +1,29 @@ +// See docs in ../ops/array_ops.cc. +#include "tensorflow/core/kernels/reshape_op.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("Reshape").Device(DEVICE_CPU).HostMemory("shape"), + ReshapeOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Reshape") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .TypeConstraint("T"), \ + ReshapeOp); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
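+// Pinning "tensor", "shape" and "output" to host memory below means the
+// int32 variant of Reshape effectively runs on the host even when the op is
+// placed on a GPU device.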
+REGISTER_KERNEL_BUILDER(Name("Reshape") + .Device(DEVICE_GPU) + .HostMemory("tensor") + .HostMemory("shape") + .HostMemory("output") + .TypeConstraint("T"), + ReshapeOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h new file mode 100644 index 0000000000..3fd3f4492e --- /dev/null +++ b/tensorflow/core/kernels/reshape_op.h @@ -0,0 +1,83 @@ +#ifndef TENSORFLOW_KERNELS_RESHAPE_OP_H_ +#define TENSORFLOW_KERNELS_RESHAPE_OP_H_ + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class ReshapeOp : public OpKernel { + public: + explicit ReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& sizes = context->input(1); + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(sizes.shape()), + errors::InvalidArgument("sizes input must be 1-D, not shape ", + sizes.shape().ShortDebugString())); + const int64 num_dims = sizes.NumElements(); + OP_REQUIRES( + context, num_dims <= 8, + errors::InvalidArgument(num_dims, " > max 8 output dims supported")); + + // Compute the output shape. Determine product of specified + // dimensions, and find the index of the unspecified one. + TensorShape shape; + int32 product = 1; + int unknown_index = -1; + auto Svec = sizes.flat(); + for (int d = 0; d < num_dims; ++d) { + const int32 size = Svec(d); + if (size == -1) { + OP_REQUIRES( + context, unknown_index == -1, + errors::InvalidArgument("only one input size may be -1, not both ", + unknown_index, " and ", d)); + unknown_index = d; + shape.AddDim(1); + } else { + OP_REQUIRES(context, size >= 0, + errors::InvalidArgument( + "size ", d, " must be non-negative, not ", size)); + shape.AddDim(size); + product *= size; + } + } + if (unknown_index != -1) { + OP_REQUIRES( + context, product > 0, + errors::InvalidArgument("cannot infer the missing input size for " + "an empty tensor unless all specified " + "input sizes are non-zero")); + const int32 missing = input.NumElements() / product; + OP_REQUIRES(context, product * missing == input.NumElements(), + errors::InvalidArgument("Input has ", input.NumElements(), + " values, which isn't divisible by ", + product)); + shape.set_dim(unknown_index, missing); + } + OP_REQUIRES(context, shape.num_elements() == input.NumElements(), + errors::InvalidArgument("Input has ", input.NumElements(), + " values, which isn't the same as ", + shape.num_elements())); + + // Actually produce the reshaped output. 
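+    // CopyFrom only installs the new shape metadata; the output shares the
+    // input's underlying buffer rather than copying it, which is also why
+    // IsExpensive() below returns false.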
+ Tensor output(input.dtype()); + CHECK(output.CopyFrom(input, shape)); + context->set_output(0, output); + } + + bool IsExpensive() override { return false; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_RESHAPE_OP_H_ diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc new file mode 100644 index 0000000000..2b22d38ad6 --- /dev/null +++ b/tensorflow/core/kernels/resize_area_op.cc @@ -0,0 +1,139 @@ +// See docs in ../ops/image_ops.cc +#define EIGEN_USE_THREADS + +#include +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class ResizeAreaOp : public OpKernel { + public: + explicit ResizeAreaOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().ShortDebugString())); + const Tensor& shape_t = context->input(1); + OP_REQUIRES(context, shape_t.dims() == 1, + errors::InvalidArgument("shape_t must be 1-dimensional", + shape_t.shape().ShortDebugString())); + OP_REQUIRES(context, shape_t.NumElements() == 2, + errors::InvalidArgument("shape_t must have two elements", + shape_t.shape().ShortDebugString())); + + auto Svec = shape_t.vec(); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({input.dim_size(0), Svec(0), + Svec(1), input.dim_size(3)}), + &output)); + const int64 batch_size = input.dim_size(0); + const int64 in_height = input.dim_size(1); + const int64 in_width = input.dim_size(2); + const int64 channels = input.dim_size(3); + const int64 out_height = output->dim_size(1); + const int64 out_width = output->dim_size(2); + + typename TTypes::ConstTensor input_data = input.tensor(); + typename TTypes::Tensor output_data = output->tensor(); + + // A temporary tensor for computing the sum. + Tensor sum_tensor; + OP_REQUIRES_OK( + context, context->allocate_temp(DataTypeToEnum::value, + TensorShape({channels}), &sum_tensor)); + typename TTypes::Tensor sum_data = sum_tensor.vec(); + + const float height_scale = in_height / static_cast(out_height); + const float width_scale = in_width / static_cast(out_width); + + // When using this algorithm for downsizing, the target pixel value is the + // weighted average of all the source pixels. The weight is determined by + // the contribution percentage of the source pixel. + // + // Let "scale" be "target_image_size/source_image_size". If 1/n of the + // source pixel contributes to the target pixel, then the weight is (1/n * + // scale); if the complete source pixel contributes to the target pixel, + // then the weight is scale. + // + // To visualize the implementation, use one dimension as an example: + // Resize in[4] to out[3]. 
+ // scale = 3/4 = 0.75 + // out[0]: in[0] and 1/3 of in[1] + // out[1]: 2/3 of in[1] and 2/3 of in[2] + // out[2]: 1/3 of in[2] and in[1] + // Hence, the output pixel values are: + // out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale + // out[1] = (in[1] * 2/3 + in[2] * 2/3 * scale + // out[2] = (in[3] * 1/3 + in[3] * 1.0) * scale + float scale = 1.0 / (height_scale * width_scale); + for (int64 b = 0; b < batch_size; ++b) { + for (int64 y = 0; y < out_height; ++y) { + const float in_y = y * height_scale; + const float in_y1 = (y + 1) * height_scale; + // The start and end height indices of all the cells that could + // contribute to the target cell. + int64 y_start = floor(in_y); + int64 y_end = ceil(in_y1); + + for (int64 x = 0; x < out_width; ++x) { + const float in_x = x * width_scale; + const float in_x1 = (x + 1) * width_scale; + // The start and end width indices of all the cells that could + // contribute to the target cell. + int64 x_start = floor(in_x); + int64 x_end = ceil(in_x1); + + sum_data.setConstant(0.0); + for (int64 i = y_start; i < y_end; ++i) { + float scale_y = + i < in_y ? i + 1 - in_y : (i + 1 > in_y1 ? in_y1 - i : 1.0); + for (int64 j = x_start; j < x_end; ++j) { + float scale_x = + j < in_x ? j + 1 - in_x : (j + 1 > in_x1 ? in_x1 - j : 1.0); + for (int64 c = 0; c < channels; ++c) { +#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val)))) + sum_data(c) += + input_data(b, BOUND(i, in_height), BOUND(j, in_width), c) * + scale_y * scale_x * scale; +#undef BOUND + } + } + } + for (int64 c = 0; c < channels; ++c) { + output_data(b, y, x, c) = sum_data(c); + } + } + } + } + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("ResizeArea") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("size"), \ + ResizeAreaOp); + +REGISTER_KERNEL(uint8); +REGISTER_KERNEL(int8); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc new file mode 100644 index 0000000000..472fc19b82 --- /dev/null +++ b/tensorflow/core/kernels/resize_bicubic_op.cc @@ -0,0 +1,121 @@ +// See docs in ../ops/image_ops.cc +#define EIGEN_USE_THREADS + +#include +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class ResizeBicubicOp : public OpKernel { + public: + explicit ResizeBicubicOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().ShortDebugString())); + const Tensor& shape_t = context->input(1); + OP_REQUIRES(context, shape_t.dims() == 1, + errors::InvalidArgument("shape_t must be 1-dimensional", + shape_t.shape().ShortDebugString())); + OP_REQUIRES(context, shape_t.NumElements() == 2, + errors::InvalidArgument("shape_t must have two elements", + shape_t.shape().ShortDebugString())); + + auto Svec = shape_t.vec(); + // Initialize shape 
to the batch size of the input, then add + // the rest of the dimensions + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({input.dim_size(0), Svec(0), + Svec(1), input.dim_size(3)}), + &output)); + const int64 batch_size = input.dim_size(0); + const int64 in_height = input.dim_size(1); + const int64 in_width = input.dim_size(2); + const int64 channels = input.dim_size(3); + const int64 out_height = output->dim_size(1); + const int64 out_width = output->dim_size(2); + + typename TTypes::ConstTensor input_data = input.tensor(); + typename TTypes::Tensor output_data = output->tensor(); + + const float height_scale = in_height / static_cast(out_height); + const float width_scale = in_width / static_cast(out_width); + + // Initialize coefficients table using Bicubic convolution algorithm. + // https://en.wikipedia.org/wiki/Bicubic_interpolation + static const int64 tab_size = (1 << 10); + static float coeffs_tab[(tab_size + 1) * 2]; + static const double A = -0.75; + for (int i = 0; i <= tab_size; ++i) { + float x = i * 1.0 / tab_size; + coeffs_tab[i * 2] = ((A + 2) * x - (A + 3)) * x * x + 1; + x += 1.0; + coeffs_tab[i * 2 + 1] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; + } + + auto cal = [](float v0, float v1, float v2, float v3, float dx) { + const int64 offset = round(dx * tab_size); + const float a0 = coeffs_tab[offset * 2 + 1]; + const float a1 = coeffs_tab[offset * 2]; + const float a2 = coeffs_tab[(tab_size - offset) * 2]; + const float a3 = coeffs_tab[(tab_size - offset) * 2 + 1]; + return a0 * v0 + a1 * v1 + a2 * v2 + a3 * v3; + }; + + float coeff[4] = {0.0}; + for (int64 b = 0; b < batch_size; ++b) { + for (int64 y = 0; y < out_height; ++y) { + const int64 in_y = floor(height_scale * y); + const float dy = height_scale * y - in_y; + for (int64 x = 0; x < out_width; ++x) { + const int64 in_x = floor(width_scale * x); + const float dx = width_scale * x - in_x; + for (int64 c = 0; c < channels; ++c) { + for (int64 i = 0; i < 4; ++i) { +#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val)))) + int64 bound_y = BOUND(in_y - 1 + i, in_height); + coeff[i] = + cal(input_data(b, bound_y, BOUND(in_x - 1, in_width), c), + input_data(b, bound_y, BOUND(in_x, in_width), c), + input_data(b, bound_y, BOUND(in_x + 1, in_width), c), + input_data(b, bound_y, BOUND(in_x + 2, in_width), c), dx); +#undef BOUND + } + output_data(b, y, x, c) = + cal(coeff[0], coeff[1], coeff[2], coeff[3], dy); + } + } + } + } + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("ResizeBicubic") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("size"), \ + ResizeBicubicOp); + +REGISTER_KERNEL(uint8); +REGISTER_KERNEL(int8); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc new file mode 100644 index 0000000000..5119b93508 --- /dev/null +++ b/tensorflow/core/kernels/resize_bilinear_op.cc @@ -0,0 +1,109 @@ +// See docs in ../ops/image_ops.cc +#define EIGEN_USE_THREADS + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include 
"third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class ResizeBilinearOp : public OpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().ShortDebugString())); + const Tensor& shape_t = context->input(1); + OP_REQUIRES(context, shape_t.dims() == 1, + errors::InvalidArgument("shape_t must be 1-dimensional", + shape_t.shape().ShortDebugString())); + OP_REQUIRES(context, shape_t.NumElements() == 2, + errors::InvalidArgument("shape_t must have two elements", + shape_t.shape().ShortDebugString())); + + auto Svec = shape_t.vec(); + // Initialize shape to the batch size of the input, then add + // the rest of the dimensions + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({input.dim_size(0), Svec(0), + Svec(1), input.dim_size(3)}), + &output)); + + const int64 batch_size = input.dim_size(0); + const int64 in_height = input.dim_size(1); + const int64 in_width = input.dim_size(2); + const int64 channels = input.dim_size(3); + const int64 out_height = output->dim_size(1); + const int64 out_width = output->dim_size(2); + + typename TTypes::ConstTensor input_data = input.tensor(); + typename TTypes::Tensor output_data = output->tensor(); + + const float height_scale = in_height / static_cast(out_height); + const float width_scale = in_width / static_cast(out_width); + + for (int b = 0; b < batch_size; ++b) { + for (int y = 0; y < out_height; ++y) { + const float in_y = y * height_scale; + const int top_y_index = static_cast(floorf(in_y)); + const int bottom_y_index = + std::min(static_cast(ceilf(in_y)), (in_height - 1)); + const float y_lerp = in_y - top_y_index; + const float inverse_y_lerp = (1.0f - y_lerp); + for (int x = 0; x < out_width; ++x) { + const float in_x = x * width_scale; + const int left_x_index = static_cast(floorf(in_x)); + const int right_x_index = + std::min(static_cast(ceilf(in_x)), (in_width - 1)); + const float x_lerp = in_x - left_x_index; + const float inverse_x_lerp = (1.0f - x_lerp); + for (int c = 0; c < channels; ++c) { + const float top_left = input_data(b, top_y_index, left_x_index, c); + const float top_right = + input_data(b, top_y_index, right_x_index, c); + const float bottom_left = + input_data(b, bottom_y_index, left_x_index, c); + const float bottom_right = + input_data(b, bottom_y_index, right_x_index, c); + const float top = + (top_left * inverse_x_lerp) + (top_right * x_lerp); + const float bottom = + (bottom_left * inverse_x_lerp) + (bottom_right * x_lerp); + output_data(b, y, x, c) = + (top * inverse_y_lerp) + (bottom * y_lerp); + } + } + } + } + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("ResizeBilinear") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("size"), \ + ResizeBilinearOp); + +REGISTER_KERNEL(uint8); +REGISTER_KERNEL(int8); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/resize_bilinear_op_test.cc b/tensorflow/core/kernels/resize_bilinear_op_test.cc new file mode 100644 index 0000000000..0ebe2e5f8c --- /dev/null +++ b/tensorflow/core/kernels/resize_bilinear_op_test.cc @@ -0,0 +1,171 @@ +#include 
"tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { + +class ResizeBilinearOpTest : public OpsTestBase { + protected: + ResizeBilinearOpTest() { + RequireDefaultOps(); + EXPECT_OK(NodeDefBuilder("resize_bilinear_op", "ResizeBilinear") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Finalize(node_def())); + EXPECT_OK(InitOp()); + } +}; + +TEST_F(ResizeBilinearOpTest, TestBilinear2x2To1x1) { + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {1, 1}); + ASSERT_OK(RunOpKernel()); + + // When scaling down, we have to arbitrarily pick a pixel from the + // original input. In this case, we choose the top/left most pixel. + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1})); + test::FillValues(&expected, {1.0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeBilinearOpTest, TestBilinear2x2To3x3) { + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {3, 3}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3, 3, 1})); + + // The corners should match the original corners, and we bilinear + // interpolate the values in between. + + // clang-format off + test::FillValues(&expected, + {1, 5.0/3, 2, + 7.0/3, 3, 10.0/3, + 3, 11.0/3, 4}); + + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeBilinearOpTest, TestBilinear3x3To4x4) { + // Input: + // 1, 2, 3, + // 4, 5, 6, + // 7, 8, 9 + AddInputFromArray(TensorShape({1, 3, 3, 1}), + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + AddInputFromArray(TensorShape({2}), {4, 4}); + ASSERT_OK(RunOpKernel()); + + // The corners should match the original corners, and we bilinear + // interpolate the values in between. 
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 4, 4, 1})); + // clang-format off + test::FillValues(&expected, + {1, 1.75, 2.5, 3, + 3.25, 4, 4.75, 5.25, + 5.5, 6.25, 7, 7.5, + 7, 7.75, 8.5, 9}); + + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeBilinearOpTest, TestBilinear2x2To3x3Batch2) { + // Input: + // 1, 2 + // 3, 4 + // + // repeated twice + AddInputFromArray(TensorShape({2, 2, 2, 1}), {1, 2, 3, 4, 1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {3, 3}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 3, 1})); + // clang-format off + test::FillValues(&expected, + {1, 5.0/3, 2, 7.0/3, 3, 10.0/3, 3, 11.0/3, 4, + 1, 5.0/3, 2, 7.0/3, 3, 10.0/3, 3, 11.0/3, 4 + }); + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeBilinearOpTest, TestBilinear2x2x2To3x3x2) { + AddInputFromArray(TensorShape({1, 2, 2, 2}), + {1, -1, 2, -2, 3, -3, 4, -4}); + AddInputFromArray(TensorShape({2}), {3, 3}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3, 3, 2})); + // clang-format off + test::FillValues(&expected, + { + 1, -1, + 5.0/3, -5.0/3, + 2, -2, + 7.0/3, -7.0/3, + 3, -3, + 10.0/3, -10.0/3, + 3, -3, + 11.0/3, -11.0/3, + 4, -4 + }); + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeBilinearOpTest, TestBilinear2x2To4x4) { + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {4, 4}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 4, 4, 1})); + // clang-format off + test::FillValues(&expected, + {1, 1.5, 2, 2, + 2, 2.5, 3, 3, + 3, 3.5, 4, 4, + 3, 3.5, 4, 4}); + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeBilinearOpTest, TestInvalidInputShape) { + AddInputFromArray(TensorShape({2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {4, 4}); + ASSERT_FALSE(RunOpKernel().ok()); +} + +TEST_F(ResizeBilinearOpTest, TestInvalidSizeDim) { + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2, 1}), {4, 4}); + ASSERT_FALSE(RunOpKernel().ok()); +} +TEST_F(ResizeBilinearOpTest, TestInvalidSizeElements) { + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({3}), {4, 4, 1}); + ASSERT_FALSE(RunOpKernel().ok()); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc new file mode 100644 index 0000000000..13089308ce --- /dev/null +++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc @@ -0,0 +1,89 @@ +// See docs in ../ops/image_ops.cc +#define EIGEN_USE_THREADS + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class ResizeNearestNeighborOp : public OpKernel { + public: + explicit ResizeNearestNeighborOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + 
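+    // Each output pixel (y, x) copies the input pixel at
+    // (floor(y * height_scale), floor(x * width_scale)), clamped to the
+    // input bounds; no interpolation is performed.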
const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().ShortDebugString())); + const Tensor& shape_t = context->input(1); + OP_REQUIRES(context, shape_t.dims() == 1, + errors::InvalidArgument("shape_t must be 1-dimensional", + shape_t.shape().ShortDebugString())); + OP_REQUIRES(context, shape_t.NumElements() == 2, + errors::InvalidArgument("shape_t must have two elements", + shape_t.shape().ShortDebugString())); + + auto Svec = shape_t.vec(); + // Initialize shape to the batch size of the input, then add + // the rest of the dimensions + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({input.dim_size(0), Svec(0), + Svec(1), input.dim_size(3)}), + &output)); + + const int64 batch_size = input.dim_size(0); + const int64 in_height = input.dim_size(1); + const int64 in_width = input.dim_size(2); + const int64 channels = input.dim_size(3); + const int64 out_height = output->dim_size(1); + const int64 out_width = output->dim_size(2); + + typename TTypes::ConstTensor input_data = input.tensor(); + typename TTypes::Tensor output_data = output->tensor(); + + const float height_scale = in_height / static_cast(out_height); + const float width_scale = in_width / static_cast(out_width); + + for (int b = 0; b < batch_size; ++b) { + for (int y = 0; y < out_height; ++y) { + const int in_y = std::min(static_cast(floorf(y * height_scale)), + (in_height - 1)); + for (int x = 0; x < out_width; ++x) { + const int in_x = std::min(static_cast(floorf(x * width_scale)), + (in_width - 1)); + for (int c = 0; c < channels; ++c) { + output_data(b, y, x, c) = input_data(b, in_y, in_x, c); + } + } + } + } + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("size"), \ + ResizeNearestNeighborOp); + +REGISTER_KERNEL(uint8); +REGISTER_KERNEL(int8); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op_test.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op_test.cc new file mode 100644 index 0000000000..8fca1f34e3 --- /dev/null +++ b/tensorflow/core/kernels/resize_nearest_neighbor_op_test.cc @@ -0,0 +1,163 @@ +// TODO(shlens, sherrym): Consider adding additional tests in image_ops.py in +// order to compare the reference implementation for image resizing in Python +// Image Library. 
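+// The tests below cover both upsampling and downsampling; expected outputs
+// follow the floor-and-clamp index mapping used by ResizeNearestNeighborOp.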
+#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { + +class ResizeNearestNeighborOpTest : public OpsTestBase { + protected: + ResizeNearestNeighborOpTest() { + RequireDefaultOps(); + EXPECT_OK(NodeDefBuilder("resize_nn", "ResizeNearestNeighbor") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Finalize(node_def())); + EXPECT_OK(InitOp()); + } +}; + +TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To1x1) { + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {1, 1}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1})); + + // clang-format off + test::FillValues(&expected, {1}); + + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To3x3) { + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {3, 3}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3, 3, 1})); + + // clang-format off + test::FillValues(&expected, + {1, 1, 2, + 1, 1, 2, + 3, 3, 4}); + + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To2x5) { + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {2, 5}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 5, 1})); + + // clang-format off + test::FillValues(&expected, + {1, 1, 1, 2, 2, + 3, 3, 3, 4, 4}); + + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To5x2) { + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {5, 2}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 5, 2, 1})); + + // clang-format off + test::FillValues(&expected, + {1, 2, + 1, 2, + 1, 2, + 3, 4, + 3, 4}); + + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To4x4) { + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({2}), {4, 4}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 4, 4, 1})); + + // clang-format off + test::FillValues(&expected, + {1, 1, 2, 2, + 1, 1, 2, 2, + 3, 3, 4, 4, + 3, 3, 4, 4}); + + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2x2x2To2x3x3x2) { + // Input: + // [ [ 1, 1 ], [ 2, 2], + // [ 3, 3 ], [ 4, 4] ], + // [ [ 5, 5 ], [ 6, 6], + // [ 7, 7 ], [ 8, 8] ] + AddInputFromArray(TensorShape({2, 2, 2, 2}), + {1, 1, 2, 2, 3, 
3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8}); + AddInputFromArray(TensorShape({2}), {3, 3}); + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 3, 2})); + + // clang-format off + test::FillValues(&expected, + {1, 1, 1, + 1, 2, 2, + 1, 1, 1, + 1, 2, 2, + 3, 3, 3, + 3, 4, 4, + 5, 5, 5, + 5, 6, 6, + 5, 5, 5, + 5, 6, 6, + 7, 7, 7, + 7, 8, 8}); + + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/restore_op.cc b/tensorflow/core/kernels/restore_op.cc new file mode 100644 index 0000000000..b52c69449c --- /dev/null +++ b/tensorflow/core/kernels/restore_op.cc @@ -0,0 +1,65 @@ +// See docs in ../ops/io_ops.cc. +#include "tensorflow/core/kernels/io.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_slice_reader.h" + +namespace tensorflow { + +class RestoreOp : public OpKernel { + public: + explicit RestoreOp(OpKernelConstruction* context) : OpKernel(context) { + int preferred_shard; + OP_REQUIRES_OK(context, + context->GetAttr("preferred_shard", &preferred_shard)); + if (preferred_shard == -1) { + preferred_shard_ = checkpoint::TensorSliceReader::kLoadAllShards; + } else { + OP_REQUIRES(context, preferred_shard >= 0, + errors::InvalidArgument("Attribute 'preferred_shard' must be " + "greater or equal to -1")); + preferred_shard_ = preferred_shard; + } + } + void Compute(OpKernelContext* context) override { + RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, + preferred_shard_, false); + } + + private: + int preferred_shard_; +}; + +REGISTER_KERNEL_BUILDER(Name("Restore").Device(DEVICE_CPU), RestoreOp); + +class RestoreSliceOp : public OpKernel { + public: + explicit RestoreSliceOp(OpKernelConstruction* context) : OpKernel(context) { + int preferred_shard; + OP_REQUIRES_OK(context, + context->GetAttr("preferred_shard", &preferred_shard)); + if (preferred_shard == -1) { + preferred_shard_ = checkpoint::TensorSliceReader::kLoadAllShards; + } else { + OP_REQUIRES(context, preferred_shard >= 0, + errors::InvalidArgument("Attribute 'preferred_shard' must be " + "greater or equal to -1")); + preferred_shard_ = preferred_shard; + } + } + void Compute(OpKernelContext* context) override { + RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, + preferred_shard_, true); + } + + private: + int preferred_shard_; +}; + +REGISTER_KERNEL_BUILDER(Name("RestoreSlice").Device(DEVICE_CPU), + RestoreSliceOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/restore_op_test.cc b/tensorflow/core/kernels/restore_op_test.cc new file mode 100644 index 0000000000..59343a8037 --- /dev/null +++ b/tensorflow/core/kernels/restore_op_test.cc @@ -0,0 +1,305 @@ +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/io/path.h" 
+#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include + +namespace tensorflow { +namespace { + +class RestoreOpTest : public OpsTestBase { + protected: + // Makes an operation to restore two tensors + void MakeRestoreOp(DataType dt) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "Restore") + .Input(FakeInput()) + .Input(FakeInput()) + .Attr("dt", dt) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(RestoreOpTest, RestoreInt) { + const string filename = io::JoinPath(testing::TmpDir(), "tensor_int"); + const string tensor_name = "tensor_int"; + + // We first need to write a tensor using the save_op + { + // Initialize an operation + NodeDef save; + ASSERT_OK(NodeDefBuilder("save", "Save") + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Input(FakeInput({DT_INT32})) + .Finalize(&save)); + + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + gtl::InlinedVector inputs; + + Status status; + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), save, &status)); + EXPECT_OK(status); + + // Run it + + // Input #0 is the file name + Tensor input_0(DT_STRING, TensorShape({})); + input_0.scalar()() = filename; + inputs.push_back({nullptr, &input_0}); + + // Input #1 is the tensor name + Tensor input_1(DT_STRING, TensorShape({})); + input_1.scalar()() = tensor_name; + inputs.push_back({nullptr, &input_1}); + + // Input #2 is an integer tensor: it's a 1-d array. + Tensor input_2(DT_INT32, TensorShape({10})); + for (int i = 0; i < 10; ++i) { + input_2.flat()(i) = i + 1; + } + inputs.push_back({nullptr, &input_2}); + + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; + params.slice_reader_cache = &slice_reader_cache_wrapper; + + OpKernelContext ctx(params); + op->Compute(&ctx); + EXPECT_OK(ctx.status()); + } + + // Now we restore + MakeRestoreOp(DT_INT32); + // Add a file name + AddInput(TensorShape({}), + [&filename](int x) -> string { return filename; }); + // Add the tensor names + AddInput(TensorShape({}), + [&tensor_name](int x) -> string { return tensor_name; }); + + ASSERT_OK(RunOpKernel()); + + // Check that we have an integer tensor + Tensor* output = GetOutput(0); + TensorShape expected({10}); + EXPECT_TRUE(output->shape().IsSameSize(expected)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(i + 1, output->flat()(i)); + } +} + +TEST_F(RestoreOpTest, RestoreFloat) { + const string filename = io::JoinPath(testing::TmpDir(), "tensor_float"); + const string tensor_name = "tensor_float"; + + // We first need to write a tensor using the save_op + { + // Initialize an operation + NodeDef save; + ASSERT_OK(NodeDefBuilder("save", "Save") + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Input(FakeInput({DT_FLOAT})) + .Finalize(&save)); + + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + gtl::InlinedVector inputs; + + Status status; + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, 
device.get(), cpu_allocator(), save, &status)); + EXPECT_OK(status); + + // Run it + + // Input #0 is the file name + Tensor input_0(DT_STRING, TensorShape({})); + input_0.scalar()() = filename; + inputs.push_back({nullptr, &input_0}); + + // Input #1 is the tensor name + Tensor input_1(DT_STRING, TensorShape({})); + input_1.scalar()() = tensor_name; + inputs.push_back({nullptr, &input_1}); + + // Input #2 is a float tensor: it's a 2-d array. + Tensor input_2(DT_FLOAT, TensorShape({2, 4})); + for (int i = 0; i < 8; ++i) { + input_2.flat()(i) = static_cast(i) / 10; + } + inputs.push_back({nullptr, &input_2}); + + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; + params.slice_reader_cache = &slice_reader_cache_wrapper; + + OpKernelContext ctx(params); + op->Compute(&ctx); + EXPECT_OK(ctx.status()); + } + + // Now we restore + MakeRestoreOp(DT_FLOAT); + // Add a file name + AddInput(TensorShape({}), + [&filename](int x) -> string { return filename; }); + // Add the tensor names + AddInput(TensorShape({}), + [&tensor_name](int x) -> string { return tensor_name; }); + + ASSERT_OK(RunOpKernel()); + + // Check that we have a float tensor. + Tensor* output = GetOutput(0); + TensorShape expected({2, 4}); + EXPECT_TRUE(output->shape().IsSameSize(expected)); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(static_cast(i) / 10, output->flat()(i)); + } +} + +class RestoreSliceOpTest : public OpsTestBase { + protected: + void MakeRestoreSliceOp(DataType dt) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "RestoreSlice") + .Input(FakeInput()) + .Input(FakeInput()) + .Input(FakeInput()) + .Attr("dt", dt) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(RestoreSliceOpTest, RestoreInt) { + const string filename = io::JoinPath(testing::TmpDir(), "tensor_int"); + const string tensor_name = "tensor_int"; + + // We first need to write a tensor using the save_op + { + // Initialize an operation + NodeDef save; + ASSERT_OK(NodeDefBuilder("save", "Save") + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Input(FakeInput({DT_INT32})) + .Finalize(&save)); + + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + gtl::InlinedVector inputs; + + Status status; + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), save, &status)); + EXPECT_OK(status); + + // Run it + + // Input #0 is the file name + Tensor input_0(DT_STRING, TensorShape({})); + input_0.scalar()() = filename; + inputs.push_back({nullptr, &input_0}); + + // Input #1 is the tensor name + Tensor input_1(DT_STRING, TensorShape({})); + input_1.scalar()() = tensor_name; + inputs.push_back({nullptr, &input_1}); + + // Input #2 is a 4x16 integer tensor. 
+ Tensor input_2(DT_INT32, TensorShape({4, 16})); + for (int64 i = 0; i < input_2.NumElements(); ++i) { + input_2.flat()(i) = i + 1; + } + inputs.push_back({nullptr, &input_2}); + + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; + params.slice_reader_cache = &slice_reader_cache_wrapper; + + OpKernelContext ctx(params); + op->Compute(&ctx); + EXPECT_OK(ctx.status()); + } + + // Now we restore + MakeRestoreSliceOp(DT_INT32); + string shape_and_slice = "4 16 0,2:-"; + // Add a file name + AddInput(TensorShape({}), + [&filename](int x) -> string { return filename; }); + // Add the tensor names + AddInput(TensorShape({}), + [&tensor_name](int x) -> string { return tensor_name; }); + // Add the tensor shape and slice + AddInput(TensorShape({}), [&shape_and_slice](int x) -> string { + return shape_and_slice; + }); + + ASSERT_OK(RunOpKernel()); + + // Check that we have an integer tensor + Tensor* output = GetOutput(0); + TensorShape expected({2, 16}); + EXPECT_TRUE(output->shape().IsSameSize(expected)); + for (int64 i = 0; i < expected.num_elements(); ++i) { + EXPECT_EQ(i + 1, output->flat()(i)); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc new file mode 100644 index 0000000000..c63dfc1e70 --- /dev/null +++ b/tensorflow/core/kernels/reverse_op.cc @@ -0,0 +1,139 @@ +// See docs in ../ops/array_ops.cc +#define EIGEN_USE_THREADS + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/reverse_op.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class ReverseOp : public OpKernel { + public: + explicit ReverseOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& dims = context->input(1); + + if (TensorShapeUtils::IsScalar(input.shape())) { + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + output->scalar() = input.scalar(); + + } else { + const int input_dims = input.dims(); + OP_REQUIRES(context, TensorShapeUtils::IsVector(dims.shape()), + errors::InvalidArgument("'dims' must be 1-dimension, not ", + dims.dims())); + + OP_REQUIRES(context, input_dims == dims.dim_size(0), + errors::InvalidArgument( + "'dims' must have the same number of values as 'input' has " + "dimensions. 
'input' has ", input_dims, "'dims' has ", + dims.dim_size(0), " values")); + OP_REQUIRES(context, input_dims <= 8, errors::Unimplemented( + "reverse is not implemented for tensors of rank > 8.")); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + +#define HANDLE_REVERSE(NDIMS) \ + case NDIMS: \ + functor::Reverse()( \ + context->eigen_device(), input.tensor(), \ + dims.vec(), output->tensor()); \ + return; + + switch (input_dims) { + HANDLE_REVERSE(0); + HANDLE_REVERSE(1); + HANDLE_REVERSE(2); + HANDLE_REVERSE(3); + HANDLE_REVERSE(4); + HANDLE_REVERSE(5); + HANDLE_REVERSE(6); + HANDLE_REVERSE(7); + HANDLE_REVERSE(8); + } +#undef HANDLE_REVERSE + } + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("Reverse") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("dims"), \ + ReverseOp) + +REGISTER_KERNEL(uint8); +REGISTER_KERNEL(int8); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(bool); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA + +// Forward declarations of the function specializations for GPU (to prevent +// building the GPU versions here, they will be built compiling _gpu.cu.cc). +namespace functor { +#define DECLARE_GPU_SPEC_DIM(T, DIM) \ + template <> \ + void Reverse::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + typename TTypes::ConstTensor dims, \ + typename TTypes::Tensor output); \ + extern template struct Reverse; +#define DECLARE_GPU_SPEC(T) \ + DECLARE_GPU_SPEC_DIM(T, 0) \ + DECLARE_GPU_SPEC_DIM(T, 1) \ + DECLARE_GPU_SPEC_DIM(T, 2) \ + DECLARE_GPU_SPEC_DIM(T, 3) \ + DECLARE_GPU_SPEC_DIM(T, 4) \ + DECLARE_GPU_SPEC_DIM(T, 5) \ + DECLARE_GPU_SPEC_DIM(T, 6) \ + DECLARE_GPU_SPEC_DIM(T, 7) \ + DECLARE_GPU_SPEC_DIM(T, 8) + +DECLARE_GPU_SPEC(uint8); +DECLARE_GPU_SPEC(int8); +DECLARE_GPU_SPEC(int32); +DECLARE_GPU_SPEC(bool); +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +#undef DECLARE_GPU_SPEC_DIM +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("Reverse") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("dims"), \ + ReverseOp) +REGISTER_GPU_KERNEL(uint8); +REGISTER_GPU_KERNEL(int8); +REGISTER_GPU_KERNEL(float); +REGISTER_GPU_KERNEL(double); +#undef REGISTER_GPU_KERNEL + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reverse_op.h b/tensorflow/core/kernels/reverse_op.h new file mode 100644 index 0000000000..bba25f70e8 --- /dev/null +++ b/tensorflow/core/kernels/reverse_op.h @@ -0,0 +1,28 @@ +#ifndef TENSORFLOW_KERNELS_REVERSE_OP_H_ +#define TENSORFLOW_KERNELS_REVERSE_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by MirrorOp to do the computations. 
+template +struct Reverse { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor dims, + typename TTypes::Tensor output) { + // mirror is in host memory + Eigen::array reverse_dims; + for (int i = 0; i < Dims; ++i) { + reverse_dims[i] = dims(i); + } + output.device(d) = input.reverse(reverse_dims); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_MIRROR_OP_H_ diff --git a/tensorflow/core/kernels/reverse_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_op_gpu.cu.cc new file mode 100644 index 0000000000..b510add3f3 --- /dev/null +++ b/tensorflow/core/kernels/reverse_op_gpu.cu.cc @@ -0,0 +1,33 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/reverse_op.h" + +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_REVERSE(DIM) \ + template struct functor::Reverse; \ + template struct functor::Reverse; \ + template struct functor::Reverse; \ + template struct functor::Reverse; \ + template struct functor::Reverse; \ + template struct functor::Reverse; +DEFINE_REVERSE(0) +DEFINE_REVERSE(1) +DEFINE_REVERSE(2) +DEFINE_REVERSE(3) +DEFINE_REVERSE(4) +DEFINE_REVERSE(5) +DEFINE_REVERSE(6) +DEFINE_REVERSE(7) +DEFINE_REVERSE(8) +#undef DEFINE_REVERSE + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/reverse_op_test.cc b/tensorflow/core/kernels/reverse_op_test.cc new file mode 100644 index 0000000000..d41c36e693 --- /dev/null +++ b/tensorflow/core/kernels/reverse_op_test.cc @@ -0,0 +1,101 @@ +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/tensor.h" +#include + +namespace tensorflow { +namespace { + +class ReverseOpTest : public OpsTestBase { + protected: + void MakeOp(DataType data_type) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "Reverse") + .Input(FakeInput(data_type)) + .Input(FakeInput()) + .Attr("T", data_type) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(ReverseOpTest, Reverse_0) { + MakeOp(DT_FLOAT); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {true}); + ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({})); + expected.scalar() = expected.scalar().constant(3.f); + test::ExpectTensorEqual(expected, *output); +} + +TEST_F(ReverseOpTest, Reverse_234) { + MakeOp(DT_FLOAT); + + // Feed and run + // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + AddInputFromArray(TensorShape({2, 3, 4}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23}); + AddInputFromArray(TensorShape({3}), {true, false, true}); + + 
ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor* params_tensor = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 4})); + // Should become + // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] + // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]] + test::FillValues( + &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, + 6, 5, 4, 11, 10, 9, 8}); + test::ExpectTensorEqual(expected, *params_tensor); +} + +TEST_F(ReverseOpTest, Reverse_1234) { + MakeOp(DT_FLOAT); + + // Feed and run + // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]] + AddInputFromArray(TensorShape({1, 2, 3, 4}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23}); + AddInputFromArray(TensorShape({4}), {true, true, false, true}); + + ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor* params_tensor = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4})); + // Should become + // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] + // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]] + test::FillValues( + &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, + 6, 5, 4, 11, 10, 9, 8}); + test::ExpectTensorEqual(expected, *params_tensor); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc new file mode 100644 index 0000000000..6673a700ef --- /dev/null +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -0,0 +1,170 @@ +// See docs in ../ops/array_ops.cc. + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/core/kernels/reverse_sequence_op.h" + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +void CheckErrors(OpKernelContext* context, int seq_dim) { + const Tensor& input = context->input(0); + const Tensor& seq_lens = context->input(1); + + auto seq_lens_t = seq_lens.vec(); + + std::vector seq_lens_vec(seq_lens_t.size()); + + // Copy seq_len info down for validity checks + context->eigen_device().memcpyDeviceToHost( + seq_lens_vec.data(), seq_lens_t.data(), + sizeof(int64) * seq_lens_t.size()); + + OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim")); + OP_REQUIRES(context, seq_dim < input.dims(), + errors::InvalidArgument("seq_dim must be < input.dims()", "( ", + seq_dim, " vs. ", input.dims(), ")")); + + OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0), + errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ", + "(", seq_lens.NumElements(), " vs. 
", + input.dim_size(seq_dim))); + + for (int d = 0; d < seq_lens_vec.size(); ++d) { + OP_REQUIRES(context, seq_lens_vec[d] >= 0, + errors::InvalidArgument("seq_lens(", d, ") < 0")); + OP_REQUIRES(context, seq_lens_vec[d] <= input.dim_size(seq_dim), + errors::InvalidArgument("seq_lens(", d, ") > input.dims(", + seq_dim, ")")); + } +} + +template <> +void CheckErrors(OpKernelContext* context, int seq_dim) { + const Tensor& input = context->input(0); + const Tensor& seq_lens = context->input(1); + + OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim")); + OP_REQUIRES(context, seq_dim < input.dims(), + errors::InvalidArgument("seq_dim must be < input.dims()", "( ", + seq_dim, " vs. ", input.dims(), ")")); + + OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0), + errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ", + "(", seq_lens.NumElements(), " vs. ", + input.dim_size(seq_dim))); +} + +template +class ReverseSequenceOp : public OpKernel { + public: + explicit ReverseSequenceOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& seq_lens = context->input(1); + + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()), + errors::InvalidArgument("seq_lens input must be 1-dim, not ", + seq_lens.dims())); + + auto seq_lens_t = seq_lens.vec(); + + CheckErrors(context, seq_dim_); + + const int input_dims = input.dims(); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + +#define HANDLE_DIM(NDIM) \ + case NDIM: \ + functor::ReverseSequence::Compute( \ + context->eigen_device(), input.tensor(), seq_dim_, \ + seq_lens_t, output->tensor()); \ + break; + + switch (input_dims) { + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + + default: + OP_REQUIRES(context, false, + errors::InvalidArgument( + "ReverseSequenceOp : Unhandled input dimensions: ", + input_dims)); + } + } + + private: + int32 seq_dim_; + + TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp); +}; + +#define REGISTER_REVERSE_SEQUENCE(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ReverseSequence").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReverseSequenceOp); + +TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE); + +#if GOOGLE_CUDA + +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T, Dims) \ + template <> \ + void ReverseSequence::Compute( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + int32 seq_dim, TTypes::ConstVec seq_lens, \ + typename TTypes::Tensor output); \ + extern template struct ReverseSequence; + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); + +} // namespace functor + +// Registration of the GPU implementations. 
+#define REGISTER_REVERSE_SEQUENCE_GPU(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ReverseSequence").Device(DEVICE_GPU).TypeConstraint("T"), \ + ReverseSequenceOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU); + +#undef REGISTER_REVERSE_SEQUENCE_GPU + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h new file mode 100644 index 0000000000..d1dd572dcb --- /dev/null +++ b/tensorflow/core/kernels/reverse_sequence_op.h @@ -0,0 +1,56 @@ +#ifndef TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_ +#define TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_ +// Generator definition for ReverseSequenceOp, must be compilable by nvcc. + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace generator { + +template +class ReverseGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + ReverseGenerator(typename TTypes::ConstTensor input, int32 seq_dim, + TTypes::ConstVec seq_lengths) + : input_(input), seq_dim_(seq_dim), seq_lengths_(seq_lengths) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + operator()(const Eigen::array& coords) const { + Eigen::array new_coords = coords; + if (coords[seq_dim_] < seq_lengths_(coords[0])) { + new_coords[seq_dim_] = seq_lengths_(coords[0]) - coords[seq_dim_] - 1; + } + + return input_(new_coords); + } + + private: + typename TTypes::ConstTensor input_; + int32 seq_dim_; + TTypes::ConstVec seq_lengths_; +}; + +} // namespace generator + +namespace functor { + +template +struct ReverseSequence { + EIGEN_ALWAYS_INLINE static void Compute( + const Device& d, typename TTypes::ConstTensor input, + int32 seq_dim, TTypes::ConstVec seq_lengths, + typename TTypes::Tensor output) { + generator::ReverseGenerator generator(input, seq_dim, seq_lengths); + output.device(d) = input.generate(generator); + } +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_ diff --git a/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc new file mode 100644 index 0000000000..7b5d533026 --- /dev/null +++ b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc @@ -0,0 +1,26 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/reverse_sequence_op.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_GPU_SPEC(T, dims) \ + template class generator::ReverseGenerator; \ + template struct functor::ReverseSequence; + +#define DEFINE_GPU_SPECS(T) \ + DEFINE_GPU_SPEC(T, 2); \ + DEFINE_GPU_SPEC(T, 3); \ + DEFINE_GPU_SPEC(T, 4); \ + DEFINE_GPU_SPEC(T, 5); + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/save_op.cc b/tensorflow/core/kernels/save_op.cc new file mode 100644 index 0000000000..71a15c643e --- /dev/null +++ b/tensorflow/core/kernels/save_op.cc @@ -0,0 +1,81 @@ +// See docs in ../ops/io_ops.cc +#include "tensorflow/core/kernels/io.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" 
+#include "tensorflow/core/util/tensor_slice_writer.h" + +namespace tensorflow { + +class SaveOp : public OpKernel { + public: + explicit SaveOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + SaveTensors(context, &checkpoint::CreateTableTensorSliceBuilder, false); + } +}; + +REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp); + +class SaveSlicesOp : public OpKernel { + public: + explicit SaveSlicesOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + SaveTensors(context, &checkpoint::CreateTableTensorSliceBuilder, true); + } +}; + +REGISTER_KERNEL_BUILDER(Name("SaveSlices").Device(DEVICE_CPU), SaveSlicesOp); + +class ShardedFilenameOp : public OpKernel { + public: + explicit ShardedFilenameOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + static const char* input_names[3] = {"basename", "shard", "num_shards"}; + for (int i = 0; i < ctx->num_inputs(); ++i) { + OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()), + errors::InvalidArgument( + input_names[i], " must be a scalar, got shape ", + ctx->input(i).shape().ShortDebugString())); + } + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); + out->scalar()() = strings::Printf( + "%s-%05d-of-%05d", ctx->input(0).scalar()().c_str(), + ctx->input(1).scalar()(), ctx->input(2).scalar()()); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ShardedFilename").Device(DEVICE_CPU), + ShardedFilenameOp); + +class ShardedFilespecOp : public OpKernel { + public: + explicit ShardedFilespecOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + static const char* input_names[2] = {"basename", "num_shards"}; + for (int i = 0; i < ctx->num_inputs(); ++i) { + OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()), + errors::InvalidArgument( + input_names[i], " must be a scalar, got shape ", + ctx->input(i).shape().ShortDebugString())); + } + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); + out->scalar()() = strings::Printf( + "%s-\?\?\?\?\?-of-%05d", ctx->input(0).scalar()().c_str(), + ctx->input(1).scalar()()); + } +}; +REGISTER_KERNEL_BUILDER(Name("ShardedFilespec").Device(DEVICE_CPU), + ShardedFilespecOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/save_op_test.cc b/tensorflow/core/kernels/save_op_test.cc new file mode 100644 index 0000000000..ee1ba492a6 --- /dev/null +++ b/tensorflow/core/kernels/save_op_test.cc @@ -0,0 +1,443 @@ +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/tensor_slice_reader.h" +#include + +namespace tensorflow { +namespace { + +class SaveOpTest : public OpsTestBase { + protected: + void 
MakeOp() { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "Save") + .Input(FakeInput()) + .Input(FakeInput()) + .Input(FakeInput( + {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32})) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(SaveOpTest, Simple) { + const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple"); + const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double", + "tensor_qint8", "tensor_qint32"}; + + MakeOp(); + // Add a file name + AddInput(TensorShape({}), + [&filename](int x) -> string { return filename; }); + + // Add the tensor names + AddInput(TensorShape({5}), + [&tensornames](int x) -> string { return tensornames[x]; }); + + // Add a 1-d integer tensor + AddInput(TensorShape({10}), [](int x) -> int32 { return x + 1; }); + + // Add a 2-d float tensor + AddInput(TensorShape({2, 4}), + [](int x) -> float { return static_cast(x) / 10; }); + + // Add a 2-d double tensor + AddInput(TensorShape({2, 4}), + [](int x) -> double { return static_cast(x) / 20; }); + + // Add a 2-d qint8 tensor + AddInput(TensorShape({3, 2}), + [](int x) -> qint8 { return *reinterpret_cast(&x); }); + + // Add a 2-d qint32 tensor + AddInput(TensorShape({2, 3}), [](int x) -> qint32 { + return *reinterpret_cast(&x) * qint8(2); + }); + + ASSERT_OK(RunOpKernel()); + + // Check that the checkpoint file is properly written + checkpoint::TensorSliceReader reader(filename, + checkpoint::OpenTableTensorSliceReader); + EXPECT_OK(reader.status()); + + // We expect to find all saved tensors + { + // The 1-d integer tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type)); + TensorShape expected({10}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_INT32, type); + + // We expect the tensor value to be correct. + TensorSlice s = TensorSlice::ParseOrDie("-"); + int data[10]; + std::fill_n(data, 10, 0); + EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(i + 1, data[i]); + } + } + + { + // The 2-d float tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type)); + TensorShape expected({2, 4}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_FLOAT, type); + + // We expect the tensor value to be correct. + TensorSlice s = TensorSlice::ParseOrDie("-:-"); + float data[8]; + std::fill_n(data, 8, 0); + EXPECT_TRUE(reader.CopySliceData("tensor_float", s, data)); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(static_cast(i) / 10, data[i]); + } + } + + { + // The 2-d double tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type)); + TensorShape expected({2, 4}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_DOUBLE, type); + + // We expect the tensor value to be correct. + TensorSlice s = TensorSlice::ParseOrDie("-:-"); + double data[8]; + std::fill_n(data, 8, 0); + EXPECT_TRUE(reader.CopySliceData("tensor_double", s, data)); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(static_cast(i) / 20, data[i]); + } + } + + { + // The 2-d qint8 tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type)); + TensorShape expected({3, 2}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_QINT8, type); + + // We expect the tensor value to be correct. 
+ TensorSlice s = TensorSlice::ParseOrDie("-:-"); + qint8 data[6]; + EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data)); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(*reinterpret_cast(&i), data[i]); + } + } + + { + // The 2-d qint32 tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type)); + TensorShape expected({2, 3}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_QINT32, type); + + // We expect the tensor value to be correct. + TensorSlice s = TensorSlice::ParseOrDie("-:-"); + qint32 data[6]; + EXPECT_TRUE(reader.CopySliceData("tensor_qint32", s, data)); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(*reinterpret_cast(&i) * qint8(2), data[i]); + } + } +} + +class SaveSlicesOpTest : public OpsTestBase { + protected: + void MakeOp() { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "SaveSlices") + .Input(FakeInput()) + .Input(FakeInput()) + .Input(FakeInput()) + .Input(FakeInput( + {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32})) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +// Here we save only slices. We restore them in a larger tensor and we check +// that the right slice is restored. It is quite tricky to check that the +// right slices are actually restored so instead we just check that +// CopySliceData() return true/false depending on the slice we ask for. +TEST_F(SaveSlicesOpTest, Slices) { + const string filename = io::JoinPath(testing::TmpDir(), "tensor_slices"); + const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double", + "tensor_qint8", "tensor_qint32"}; + // Specifies that the data we save are slices of larger tensors. + // See core/framework/tensor_slice.h for the slice syntax. + const string tensorshapes[] = { + "10 -", // Full contents of a 10 element vector. + "2 4 -:0,2", // A 2x2 slice of a 2x4 tensor. + "2 4 0,1:2,2", // A 1x2 slice of a 2x4 tensor. + "3 2 -:-", // Full contents of a 3x2 tensor. + "2 3 1,1:2,1" // Another 1x1 slice of a2x3 tensor. + }; + + MakeOp(); + // Add a file name + AddInput(TensorShape({}), + [&filename](int x) -> string { return filename; }); + + // Add the tensor names + AddInput(TensorShape({5}), + [&tensornames](int x) -> string { return tensornames[x]; }); + + // Add the tensor shapes and slices + AddInput(TensorShape({5}), [&tensorshapes](int x) -> string { + return tensorshapes[x]; + }); + + // Add a 1-d integer tensor + AddInput(TensorShape({10}), [](int x) -> int32 { return x + 1; }); + + // Add a 2-d float tensor + AddInput(TensorShape({2, 2}), + [](int x) -> float { return static_cast(x) / 10; }); + + // Add a 2-d double tensor + AddInput(TensorShape({1, 2}), + [](int x) -> double { return static_cast(x) / 20; }); + + // Add a 2-d qint8 tensor + AddInput(TensorShape({3, 2}), + [](int x) -> qint8 { return *reinterpret_cast(&x); }); + + // Add a 2-d qint32 tensor + AddInput(TensorShape({1, 1}), [](int x) -> qint32 { + return *reinterpret_cast(&x) * qint8(2); + }); + + ASSERT_OK(RunOpKernel()); + + // Check that the checkpoint file is properly written + checkpoint::TensorSliceReader reader(filename, + checkpoint::OpenTableTensorSliceReader); + EXPECT_OK(reader.status()); + + // We expect to find all saved tensors + { + // The 1-d integer tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type)); + TensorShape expected({10}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_INT32, type); + + // We saved the full tensor so we should be able to read it all. 
+ TensorSlice s = TensorSlice::ParseOrDie("-"); + int data[10]; + EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data)); + } + + { + // The 2-d float tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type)); + TensorShape expected({2, 4}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_FLOAT, type); + + // We saved the slice "-:0,2" so we should not be able to read the full + // tensor. + TensorSlice full_slice = TensorSlice::ParseOrDie("-:-"); + TensorSlice saved_slice = TensorSlice::ParseOrDie("-:0,2"); + float data[8]; + EXPECT_FALSE(reader.CopySliceData("tensor_float", full_slice, data)); + EXPECT_TRUE(reader.CopySliceData("tensor_float", saved_slice, data)); + } + + { + // The 2-d double tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type)); + TensorShape expected({2, 4}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_DOUBLE, type); + + // We saved the slice "0,1:2,2" so we should not be able to read the full + // tensor. + TensorSlice full_slice = TensorSlice::ParseOrDie("-:-"); + TensorSlice saved_slice = TensorSlice::ParseOrDie("0,1:2,2"); + double data[8]; + EXPECT_FALSE(reader.CopySliceData("tensor_double", full_slice, data)); + EXPECT_TRUE(reader.CopySliceData("tensor_double", saved_slice, data)); + } + + { + // The 2-d qint8 tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type)); + TensorShape expected({3, 2}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_QINT8, type); + + // We saved the full slice. + TensorSlice s = TensorSlice::ParseOrDie("-:-"); + qint8 data[6]; + EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data)); + } + + { + // The 2-d qint32 tensor + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type)); + TensorShape expected({2, 3}); + EXPECT_TRUE(shape.IsSameSize(expected)); + EXPECT_EQ(DT_QINT32, type); + + // We expect the tensor value to be correct. + TensorSlice s = TensorSlice::ParseOrDie("1,1:2,1"); + TensorSlice full_slice = TensorSlice::ParseOrDie("-:-"); + TensorSlice saved_slice = TensorSlice::ParseOrDie("1,1:2,1"); + qint32 data[6]; + EXPECT_FALSE(reader.CopySliceData("tensor_qint32", full_slice, data)); + EXPECT_TRUE(reader.CopySliceData("tensor_qint32", saved_slice, data)); + } +} + +class SaveOpSlices2Test : public OpsTestBase { + protected: + void MakeOp() { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "SaveSlices") + .Input(FakeInput()) + .Input(FakeInput()) + .Input(FakeInput()) + .Input(FakeInput({DT_INT32, DT_INT32, DT_FLOAT})) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(SaveOpSlices2Test, TwoSlices) { + const string filename = io::JoinPath(testing::TmpDir(), "three_slices"); + // We will save 2 slices of the tensor named "four_by_sixteen" which is 4x16, + // and one slice of the "small" tensor. + const string tensornames[] = {"four_by_sixteen", "four_by_sixteen", "small"}; + const string tensorshapes[] = { + // Slice specifications for the 2 slices of "four_by_sixteen" + "4 16 0,2:-", // 1st slice covers indices 0 and 1 in the first dim. + "4 16 2,2:-", // 2nd slice covers indices 2 and 3 in the first dim. + "" // We save the full "small" tensors. 
+ }; + + MakeOp(); + // Add a file name + AddInput(TensorShape({}), + [&filename](int x) -> string { return filename; }); + + // Add the tensor names + AddInput(TensorShape({3}), + [&tensornames](int x) -> string { return tensornames[x]; }); + + // Add the tensor shapes and slices + AddInput(TensorShape({3}), [&tensorshapes](int x) -> string { + return tensorshapes[x]; + }); + + // Add an integer tensor for slice 0,2:- of a 4x16 tensor: It is 2x16. + AddInput(TensorShape({2, 16}), [](int x) -> int32 { return x + 1; }); + + // Add an integer tensor for slice 2,2:- of a 4x16 tensor: It is 2x16. + AddInput(TensorShape({2, 16}), + [](int x) -> int32 { return 10 * (x + 1); }); + + // Add a float tensor for "small" + AddInput(TensorShape({2, 4}), + [](int x) -> float { return static_cast(x) / 10; }); + + ASSERT_OK(RunOpKernel()); + + // Check that the checkpoint file is properly written + checkpoint::TensorSliceReader reader(filename, + checkpoint::OpenTableTensorSliceReader); + EXPECT_OK(reader.status()); + + { + // Reload the two slices of "four_by_sixteen" into that tensor. + Tensor reloaded(DT_INT32, {4, 16}); + + // We expect to find all slices + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("four_by_sixteen", &shape, &type)); + EXPECT_TRUE(shape.IsSameSize(reloaded.shape())); + EXPECT_EQ(type, reloaded.dtype()); + + // Reload the whole tensor. + EXPECT_TRUE(reader.CopySliceData("four_by_sixteen", + TensorSlice(reloaded.dims()), + reloaded.flat().data())); + + { + auto slice = reloaded.Slice(0, 2).flat(); + for (int i = 0; i < slice.size(); ++i) { + EXPECT_EQ(i + 1, slice(i)); + } + } + { + auto slice = reloaded.Slice(2, 4).flat(); + for (int i = 0; i < slice.size(); ++i) { + EXPECT_EQ(10 * (i + 1), slice(i)); + } + } + } + + { + // Reload the small float tensor. + Tensor reloaded(DT_FLOAT, {2, 4}); + + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("small", &shape, &type)); + EXPECT_TRUE(shape.IsSameSize(reloaded.shape())); + EXPECT_EQ(DT_FLOAT, reloaded.dtype()); + + EXPECT_TRUE(reader.CopySliceData("small", TensorSlice(reloaded.dims()), + reloaded.flat().data())); + + for (int64 i = 0; i < reloaded.NumElements(); ++i) { + EXPECT_EQ(static_cast(i) / 10, reloaded.flat().data()[i]); + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc new file mode 100644 index 0000000000..88fcc1bdcc --- /dev/null +++ b/tensorflow/core/kernels/scatter_op.cc @@ -0,0 +1,167 @@ +// See docs in ../ops/state_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +enum class UpdateOp { ASSIGN, ADD, SUB }; + +template +class ScatterUpdateOp : public OpKernel { + public: + // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, + // etc. here. Should we have the framework do some sort of + // integer promotion automatically, or should that be something + // that users have to do explicitly with a conversion operator + // in the graph? 
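// Illustrative sketch (not part of the original kernel; the name and the
// plain float/int types are simplifications): the ASSIGN case reduces to
// params[indices[i], :] = updates[i, :] for each i over row-major buffers.
// DoCompute() below implements the same idea with Eigen chips and also
// handles ADD, SUB and optional ref locking.
static void ScatterAssignSketch(float* params, int cols, const int* indices,
                                int n_indices, const float* updates) {
  for (int i = 0; i < n_indices; ++i) {
    for (int j = 0; j < cols; ++j) {
      params[indices[i] * cols + j] = updates[i * cols + j];
    }
  }
}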
+ explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* c) override { + if (use_exclusive_lock_) { + // Hold mutex while we apply updates + mutex_lock l(*c->input_ref_mutex(0)); + DoCompute(c); + } else { + DoCompute(c); + } + } + + private: + bool use_exclusive_lock_; + + // Check whether updates.shape = indices.shape + params.shape[1:] + static bool ValidShapes(const Tensor& params, const Tensor& updates, + const Tensor& indices) { + if (updates.dims() != indices.dims() + params.dims() - 1) return false; + for (int d = 0; d < indices.dims(); d++) { + if (updates.dim_size(d) != indices.dim_size(d)) { + return false; + } + } + for (int d = 1; d < params.dims(); d++) { + if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) { + return false; + } + } + return true; + } + + void DoCompute(OpKernelContext* c) { + Tensor Tparams = c->mutable_input(0, use_exclusive_lock_); + OP_REQUIRES(c, Tparams.IsInitialized(), + errors::FailedPrecondition("Null ref for params")); + const Tensor& Tindices = c->input(1); + const Tensor& Tupdates = c->input(2); + OP_REQUIRES( + c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()), + errors::InvalidArgument("params must be at least 1-D, got shape ", + Tparams.shape().ShortDebugString())); + OP_REQUIRES( + c, ValidShapes(Tparams, Tupdates, Tindices), + errors::InvalidArgument( + "Must have updates.shape = indices.shape + params.shape[1:], got ", + "updates.shape ", Tupdates.shape().ShortDebugString(), + ", indices.shape ", Tindices.shape().ShortDebugString(), + ", params.shape ", Tparams.shape().ShortDebugString())); + const Index N = Tindices.NumElements(); + + // We always return the input ref. 
+ c->forward_ref_input_to_ref_output(0, 0); + + if (N > 0) { + const Index first_dim_size = Tparams.dim_size(0); + // Validate all the indices are in range + auto Tindices_vec = Tindices.flat(); + for (Index i = 0; i < N; i++) { + const Index index = Tindices_vec(i); + OP_REQUIRES(c, index >= 0 && index < first_dim_size, + errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in indices is out of range"))); + } + auto Tparams_flat = Tparams.flat_outer_dims(); + auto Tupdates_flat = + Tupdates.shaped({N, Tupdates.NumElements() / N}); + for (Index i = 0; i < N; i++) { + // Copy last Ndim-1 dimensions of Tupdates[i] to + // Tparams[Tindices[i]] + switch (op) { + case UpdateOp::ASSIGN: { + Tparams_flat.template chip<0>(Tindices_vec(i)) = + Tupdates_flat.template chip<0>(i); + break; + } + case UpdateOp::ADD: { + Tparams_flat.template chip<0>(Tindices_vec(i)) += + Tupdates_flat.template chip<0>(i); + break; + } + case UpdateOp::SUB: { + Tparams_flat.template chip<0>(Tindices_vec(i)) -= + Tupdates_flat.template chip<0>(i); + break; + } + } + } + } + } +}; + +#define REGISTER_SCATTER_UPDATE(type, index_type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ScatterUpdate") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + ScatterUpdateOp); + +#define REGISTER_SCATTER_UPDATE_INT32(type) REGISTER_SCATTER_UPDATE(type, int32) +#define REGISTER_SCATTER_UPDATE_INT64(type) REGISTER_SCATTER_UPDATE(type, int64) + +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT32); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT64); + +#undef REGISTER_SCATTER_UPDATE_INT64 +#undef REGISTER_SCATTER_UPDATE_INT32 +#undef REGISTER_SCATTER_UPDATE + +#define REGISTER_SCATTER_ADD(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("ScatterAdd") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + ScatterUpdateOp); + +#define REGISTER_SCATTER_ADD_INT32(type) REGISTER_SCATTER_ADD(type, int32) +#define REGISTER_SCATTER_ADD_INT64(type) REGISTER_SCATTER_ADD(type, int64) + +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_INT32); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_INT64); + +#undef REGISTER_SCATTER_ADD_INT32 +#undef REGISTER_SCATTER_ADD_INT64 +#undef REGISTER_SCATTER_ADD + +#define REGISTER_SCATTER_SUB(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("ScatterSub") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + ScatterUpdateOp); + +#define REGISTER_SCATTER_SUB_INT32(type) REGISTER_SCATTER_SUB(type, int32) +#define REGISTER_SCATTER_SUB_INT64(type) REGISTER_SCATTER_SUB(type, int64) + +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_SUB_INT32); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_SUB_INT64); + +#undef REGISTER_SCATTER_SUB_INT64 +#undef REGISTER_SCATTER_SUB_INT32 +#undef REGISTER_SCATTER_SUB + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc new file mode 100644 index 0000000000..8885f1edb3 --- /dev/null +++ b/tensorflow/core/kernels/scatter_op_test.cc @@ -0,0 +1,255 @@ +#include +#include +#include + +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include 
"tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace { + +class ScatterUpdateOpTest : public OpsTestBase { + protected: + void MakeOp(DataType index_type) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate") + .Input(FakeInput(DT_FLOAT_REF)) + .Input(FakeInput(index_type)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(ScatterUpdateOpTest, Simple_TwoD32) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({3}), {0, 4, 2}); + AddInputFromArray(TensorShape({3, 3}), + {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); + ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3})); + test::FillValues(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001, + 10002, 0, 0, 0, 777, 778, 779}); + test::ExpectTensorEqual(expected, params_tensor); +} + +TEST_F(ScatterUpdateOpTest, Simple_Two64) { + MakeOp(DT_INT64); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({3}), {0, 4, 2}); + AddInputFromArray(TensorShape({3, 3}), + {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); + ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3})); + test::FillValues(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001, + 10002, 0, 0, 0, 777, 778, 779}); + test::ExpectTensorEqual(expected, params_tensor); +} + +TEST_F(ScatterUpdateOpTest, Simple_ZeroD) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5}), {0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {101}); + ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); + test::FillValues(&expected, {0, 0, 0, 101, 0}); + test::ExpectTensorEqual(expected, params_tensor); +} + +TEST_F(ScatterUpdateOpTest, Simple_OneD) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5}), {0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({3}), {0, 4, 2}); + AddInputFromArray(TensorShape({3}), {100, 101, 102}); + ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); + test::FillValues(&expected, {100, 0, 102, 0, 101}); + test::ExpectTensorEqual(expected, params_tensor); +} + +TEST_F(ScatterUpdateOpTest, HigherRank) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({2, 3}), {0, 4, 2, 1, 3, 6}); + AddInputFromArray(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60}); + ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({8})); + test::FillValues(&expected, {10, 40, 30, 50, 
20, 0, 60, 0}); + test::ExpectTensorEqual(expected, params_tensor); +} + +TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({3}), {0, 4, 99}); + AddInputFromArray(TensorShape({3, 3}), + {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Index 99 at offset 2 in indices is out of range")) + << s; +} + +TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({1, 3}), {0, 4, 99}); + AddInputFromArray(TensorShape({3, 3}), + {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Must have updates.shape = indices.shape + " + "params.shape[1:], got ")) + << s; +} + +TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({3}), {0, 4, 2}); + AddInputFromArray( + TensorShape({3, 4}), + {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Must have updates.shape = indices.shape + " + "params.shape[1:], got ")) + + << s; +} + +TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray(TensorShape({3}), {0, 4, 2}); + AddInputFromArray(TensorShape({2, 3}), + {100, 101, 102, 10000, 10001, 10002}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Must have updates.shape = indices.shape + " + "params.shape[1:], got ")) + << s; +} + +class ScatterUpdateBM : public ScatterUpdateOpTest { + public: + virtual void TestBody() {} + void MakeBenchmarkOp(const char* op, DataType index_type) { + ASSERT_OK(NodeDefBuilder("myop", op) + .Input(FakeInput(DT_FLOAT_REF)) + .Input(FakeInput(index_type)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + } +}; + +template +static void BM_ScatterHelper(int iters, int embedding_size, const char* op) { + testing::StopTiming(); + const int kRows = 10000000 / embedding_size; + std::vector values; + for (int i = 0; i < kRows * embedding_size; i++) { + values.push_back(i); + } + const int kNumUpdates = 1000; + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + std::vector indices; + std::vector updates; + for (int i = 0; i < kNumUpdates; i++) { + indices.push_back(rnd.Uniform(kRows)); + for (int j = 0; j < embedding_size; j++) { + updates.push_back(i * 10 + j); + } + } + + ScatterUpdateBM bm; + bm.MakeBenchmarkOp(op, DataTypeToEnum::v()); + bm.AddInputFromArray(TensorShape({kRows, embedding_size}), values); + bm.AddInputFromArray(TensorShape({kNumUpdates}), indices); + bm.AddInputFromArray(TensorShape({kNumUpdates, embedding_size}), + updates); + testing::ItemsProcessed((static_cast(kNumUpdates) * embedding_size) * + iters); + testing::StartTiming(); + while (iters-- > 0) { + Status s = bm.RunOpKernel(); + } +} + +static void BM_ScatterUpdateInt32(int iters, int embedding_size) { + 
BM_ScatterHelper(iters, embedding_size, "ScatterUpdate"); +} +static void BM_ScatterUpdateInt64(int iters, int embedding_size) { + BM_ScatterHelper(iters, embedding_size, "ScatterUpdate"); +} + +static void BM_ScatterAddInt32(int iters, int embedding_size) { + BM_ScatterHelper(iters, embedding_size, "ScatterAdd"); +} +static void BM_ScatterAddInt64(int iters, int embedding_size) { + BM_ScatterHelper(iters, embedding_size, "ScatterAdd"); +} + +BENCHMARK(BM_ScatterUpdateInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterUpdateInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + +BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc new file mode 100644 index 0000000000..2b6a8c5a88 --- /dev/null +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -0,0 +1,466 @@ +// See docs in ../ops/math_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +// This operator handles reducing segments along the first dimension. +// See core/ops/math_ops.cc for more details. +template +class SegmentReductionOp : public OpKernel { + public: + explicit SegmentReductionOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& segment_ids = context->input(1); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), + errors::InvalidArgument("segment_ids should be a vector.")); + const int64 num_indices = segment_ids.NumElements(); + OP_REQUIRES(context, num_indices == input.dim_size(0), + errors::InvalidArgument( + "segment_ids should be the same size as dimension 0 of" + " input.")); + + auto input_flat = input.flat_outer_dims(); + const int64 num_col = input_flat.dimension(1); + + const auto segment_vec = segment_ids.vec(); + // Note that the current implementation assumes that segment_vec values are + // sorted. + const Index output_rows = + num_indices > 0 ? segment_vec(num_indices - 1) + 1 : 0; + + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, output_rows); + + // Note that we do not initialize the output buffer with a default value. + // We require that segment ids be sorted and cover all values (otherwise we + // return an error). + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + auto output_flat = output->flat_outer_dims(); + +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::DSizes dims_to_reduce; + dims_to_reduce[0] = 0; +#else + Eigen::IndexList> dims_to_reduce; +#endif + Index start = 0, end = 1; + // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it + // across threads. 
+ Eigen::DSizes out_slice_shape(num_col); + while (end <= num_indices) { + if (end < num_indices) { + if (segment_vec(start) == segment_vec(end)) { + ++end; + continue; + } + // We have a new segment here. Verify that the segment ids grow by one + // each time, so that we cover every possible output value. + OP_REQUIRES( + context, segment_vec(start) + 1 == segment_vec(end), + errors::InvalidArgument("segment ids are not increasing by 1")); + } + + // Process segment [start, end) + const T* in_slice_ptr = &input_flat(start, 0); + typedef Eigen::TensorMap, + Eigen::Unaligned> OutT; + T* out_slice_ptr = &output_flat(segment_vec(start), 0); + OutT out_slice(out_slice_ptr, out_slice_shape); + // We don't use out_slice.device(context->egien_device) + // because these pieces of work are likely to be very small and + // the context switching overhead dwarfs any benefit we get from + // using another thread to do this work. + if (start == end - 1) { + typedef Eigen::TensorMap, + Eigen::Unaligned> InT; + InT in_slice(in_slice_ptr, out_slice_shape); + out_slice = in_slice; + } else { + Eigen::DSizes in_slice_shape(end - start, + num_col); + typedef Eigen::TensorMap, + Eigen::Unaligned> InT; + InT in_slice(in_slice_ptr, in_slice_shape); + + out_slice = in_slice.reduce(dims_to_reduce, Reducer()); + } + start = end; + ++end; + } + } +}; + +#define REGISTER_CPU_KERNELS(type, index_type) \ + REGISTER_KERNEL_BUILDER( \ + Name("SegmentSum") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SegmentReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SegmentMean") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SegmentReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SegmentProd") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SegmentReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SegmentMin") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SegmentReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SegmentMax") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SegmentReductionOp>); + +#define REGISTER_CPU_KERNELS_ALL(type) \ + REGISTER_CPU_KERNELS(type, int32); \ + REGISTER_CPU_KERNELS(type, int64); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS_ALL); +#undef REGISTER_CPU_KERNELS +#undef REGISTER_CPU_KERNELS_ALL + +// Similar to SegmentReductionOp but can handle unsorted segment definitions and +// specifying size of output. 
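// Illustrative sketch (not part of the original file; the name and the plain
// int/float types are simplifications): the semantics of the unsorted variant
// defined below. Unlike SegmentReductionOp above, segment ids may appear in
// any order and the number of output rows is supplied explicitly, so the
// output must be zero-initialized before accumulation.
static void UnsortedSegmentSumSketch(const float* data, const int* segment_ids,
                                     int n_rows, int cols, int num_segments,
                                     float* output) {
  for (int i = 0; i < num_segments * cols; ++i) output[i] = 0.f;
  for (int i = 0; i < n_rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      output[segment_ids[i] * cols + j] += data[i * cols + j];
    }
  }
}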
+template +class UnsortedSegmentSumOp : public OpKernel { + public: + explicit UnsortedSegmentSumOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& data = context->input(0); + const Tensor& segment_ids = context->input(1); + const Tensor& num_segments = context->input(2); + + OP_REQUIRES( + context, TensorShapeUtils::IsLegacyScalar(num_segments.shape()), + errors::InvalidArgument("num_segments should be a scalar, not shape ", + num_segments.shape().ShortDebugString())); + + OP_REQUIRES(context, + TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()), + errors::InvalidArgument( + "data.shape = ", data.shape().ShortDebugString(), + " does not start with segment_ids.shape = ", + segment_ids.shape().ShortDebugString())); + + const auto segment_flat = segment_ids.flat(); + const int32 N = segment_flat.dimension(0); + const int32 output_rows = num_segments.scalar()(); + + if (N > 0) { + Eigen::Tensor m = segment_flat.maximum(); + OP_REQUIRES( + context, m() < output_rows, + errors::InvalidArgument("More segments found than output size")); + } + + TensorShape output_shape; + output_shape.AddDim(output_rows); + for (int i = segment_ids.dims(); i < data.dims(); i++) { + output_shape.AddDim(data.dim_size(i)); + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + auto output_flat = output->flat_outer_dims(); + output_flat.setZero(); + + if (data.NumElements() > 0) { + auto data_flat = data.shaped({N, data.NumElements() / N}); + for (int i = 0; i < N; ++i) { + output_flat.template chip<0>(segment_flat(i)) += + data_flat.template chip<0>(i); + } + } + } +}; + +#define REGISTER_CPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + UnsortedSegmentSumOp); + +#define REGISTER_CPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_CPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_CPU_UNSORTED_KERNELS(type, int64); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL); +#undef REGISTER_CPU_UNSORTED_KERNELS +#undef REGISTER_CPU_UNSORTED_KERNELS_ALL + +// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented +// by two dense tensors, one containing the data, and the other containing +// indices into the data. +template +class SparseSegmentReductionOpBase : public OpKernel { + public: + explicit SparseSegmentReductionOpBase(OpKernelConstruction* context, + bool is_mean) + : OpKernel(context), is_mean_(is_mean) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& indices = context->input(1); + const Tensor& segment_ids = context->input(2); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument("indices should be a vector.")); + OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), + errors::InvalidArgument("segment_ids should be a vector.")); + + const int32 num_indices = indices.NumElements(); + OP_REQUIRES(context, num_indices == segment_ids.NumElements(), + errors::InvalidArgument( + "segment_ids and indices should have same size.")); + + auto input_flat = input.flat_outer_dims(); + + const auto indices_vec = indices.vec(); + const auto segment_vec = segment_ids.vec(); + // Note that the current implementation assumes that segment_vec values are + // sorted. 
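+    // E.g. (illustrative): sorted segment_ids = [0, 0, 1, 2, 2] gives
+    // output_rows = segment_vec(4) + 1 = 3; each output row reduces the input
+    // rows selected by the corresponding entries of `indices`.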
+ const int32 output_rows = + num_indices > 0 ? segment_vec(num_indices - 1) + 1 : 0; + + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, output_rows); + + // Note that we do not initialize the output buffer with a default value. + // We require that segment ids be sorted and cover all values (otherwise we + // return an error). + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + if (num_indices == 0) return; + auto output_flat = output->flat_outer_dims(); + + int32 start = 0, end = 1; + while (end <= num_indices) { + if (end < num_indices) { + if (segment_vec(start) == segment_vec(end)) { + ++end; + continue; + } + // We have a new segment here. Verify that the segment ids grow by one + // each time, so that we cover every possible output value. + OP_REQUIRES( + context, segment_vec(start) + 1 == segment_vec(end), + errors::InvalidArgument("segment ids are not increasing by 1")); + } + + auto out = output_flat.template chip<0>(segment_vec(start)); +#define I(i) input_flat.template chip<0>(indices_vec(start + i)) + int num = end - start; + if (num == 1) { + out = I(0); + } else { + int r = num % 8; + T m = (is_mean_ && (num < 10)) ? num : 1; + switch (r) { + case 2: + out = (I(0) + I(1)) / m; + break; + case 3: + out = (I(0) + I(1) + I(2)) / m; + break; + case 4: + out = (I(0) + I(1) + I(2) + I(3)) / m; + break; + case 5: + out = (I(0) + I(1) + I(2) + I(3) + I(4)) / m; + break; + case 6: + out = (I(0) + I(1) + I(2) + I(3) + I(4) + I(5)) / m; + break; + case 7: + out = (I(0) + I(1) + I(2) + I(3) + I(4) + I(5) + I(6)) / m; + break; + case 0: + out = (I(0) + I(1) + I(2) + I(3) + I(4) + I(5) + I(6) + I(7)) / m; + r = 8; + break; + case 1: + out = + (I(0) + I(1) + I(2) + I(3) + I(4) + I(5) + I(6) + I(7) + I(8)) / + m; + r = 9; + break; + } + for (; r < num; r += 8) { + out += I(r) + I(r + 1) + I(r + 2) + I(r + 3) + I(r + 4) + I(r + 5) + + I(r + 6) + I(r + 7); + } +#undef I + if (is_mean_ && num >= 10) { + out = out / static_cast(num); + } + } + start = end; + ++end; + } + } + + private: + bool is_mean_; +}; + +template +class SparseSegmentReductionMeanOp + : public SparseSegmentReductionOpBase { + public: + explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context) + : SparseSegmentReductionOpBase(context, true /*is_mean*/) {} +}; + +template +class SparseSegmentReductionSumOp + : public SparseSegmentReductionOpBase { + public: + explicit SparseSegmentReductionSumOp(OpKernelConstruction* context) + : SparseSegmentReductionOpBase(context, false /*is_mean*/) {} +}; + +#define REGISTER_CPU_SPARSE_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseSegmentSum").Device(DEVICE_CPU).TypeConstraint("T"), \ + SparseSegmentReductionSumOp); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS); +#undef REGISTER_CPU_SPARSE_KERNELS + +#define REGISTER_CPU_SPARSE_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseSegmentMean").Device(DEVICE_CPU).TypeConstraint("T"), \ + SparseSegmentReductionMeanOp); +REGISTER_CPU_SPARSE_KERNELS(float); +REGISTER_CPU_SPARSE_KERNELS(double); +#undef REGISTER_CPU_SPARSE_KERNELS + +template +class SparseSegmentMeanGradOp : public OpKernel { + public: + explicit SparseSegmentMeanGradOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& indices = context->input(1); + const Tensor& segment_ids = context->input(2); + const Tensor& output_dim0 = 
context->input(3); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument("indices should be a vector.")); + OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), + errors::InvalidArgument("segment_ids should be a vector.")); + OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(output_dim0.shape()), + errors::InvalidArgument("output_dim0 should be a scalar.")); + + const int64 N = indices.NumElements(); + OP_REQUIRES(context, N == segment_ids.NumElements(), + errors::InvalidArgument( + "segment_ids and indices should have same size.")); + const int32 M = output_dim0.scalar()(); + + auto input_flat = input.flat_outer_dims(); + const auto indices_vec = indices.vec(); + const auto segment_vec = segment_ids.vec(); + + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, M); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + if (M == 0 || N == 0) return; + + // Note that similar to SparseSegmentMean, we assume that segment_vec is + // already sorted and has non-negative values. + int num_segments = segment_vec(N - 1) + 1; + OP_REQUIRES(context, input.dim_size(0) == num_segments, + errors::InvalidArgument("Invalid number of segments")); + + // Compute scaling factors for input. + std::vector scaling(num_segments, 0.0); + for (int64 i = 0; i < N; ++i) { + scaling[segment_vec(i)] += 1; + } + for (int i = 0; i < scaling.size(); ++i) { + scaling[i] = 1.0 / std::max(scaling[i], 1.0); + } + + auto output_flat = output->flat_outer_dims(); + output_flat.setZero(); + std::vector is_modified(M, false); + + for (int64 i = 0; i < N; ++i) { + int output_idx = indices_vec(i); + int idx = segment_vec(i); + T scale = static_cast(scaling[idx]); + if (is_modified[output_idx]) { + if (scale == 1.0) { + output_flat.template chip<0>(output_idx) += + input_flat.template chip<0>(idx); + } else { + output_flat.template chip<0>(output_idx) += + input_flat.template chip<0>(idx) * scale; + } + } else { + if (scale == 1.0) { + output_flat.template chip<0>(output_idx) = + input_flat.template chip<0>(idx); + } else { + output_flat.template chip<0>(output_idx) = + input_flat.template chip<0>(idx) * scale; + } + } + is_modified[output_idx] = true; + } + } +}; + +#define REGISTER_CPU_SPARSE_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + SparseSegmentMeanGradOp); + +REGISTER_CPU_SPARSE_KERNELS(float); +REGISTER_CPU_SPARSE_KERNELS(double); + +#undef REGISTER_CPU_SPARSE_KERNELS +} // namespace tensorflow diff --git a/tensorflow/core/kernels/segment_reduction_ops_test.cc b/tensorflow/core/kernels/segment_reduction_ops_test.cc new file mode 100644 index 0000000000..87647a21a8 --- /dev/null +++ b/tensorflow/core/kernels/segment_reduction_ops_test.cc @@ -0,0 +1,157 @@ +#include + +#include "tensorflow/core/public/session_options.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include 
"tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" + +namespace tensorflow { + +template +static void BM_SegmentReduction(int iters, string reduction, Index num_rows, + Index num_cols, Index segment_size) { + testing::StopTiming(); + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + // Create inputs + gtl::InlinedVector reduction_inputs; + TensorShape shape1({num_rows, num_cols}); + Tensor input1(DT_FLOAT, shape1); + reduction_inputs.push_back({nullptr, &input1}); + + TensorShape shape2({num_rows}); + Tensor input2(DataTypeToEnum::v(), shape2); + test::FillFn(&input2, [&num_rows, &segment_size](Index i) -> Index { + return std::min(i / segment_size, num_rows - 1); + }); + reduction_inputs.push_back({nullptr, &input2}); + + NodeDef reduction_node_def; + TF_CHECK_OK(NodeDefBuilder(reduction, reduction) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DataTypeToEnum::v())) + .Finalize(&reduction_node_def)); + Status status; + std::unique_ptr reduction_op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), reduction_node_def, &status)); + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &reduction_inputs; + params.op_kernel = reduction_op.get(); + params.output_alloc_attr = [&device, &reduction_op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = + (reduction_op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + + std::unique_ptr reduction_context( + new OpKernelContext(params)); + + reduction_op->Compute(reduction_context.get()); + TF_CHECK_OK(reduction_context->status()); + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + delete reduction_context->release_output(0).tensor; + reduction_op->Compute(reduction_context.get()); + } + int64 bytes_per_iter = + static_cast(num_rows * num_cols * sizeof(float)); + testing::BytesProcessed(bytes_per_iter * iters); +} + +#define BM_Reduce(O, R, C, S) \ + static void BM_Reduce_##O##_##R##_##C##_##S##_int32(int iters) { \ + BM_SegmentReduction(iters, #O, R, C, S); \ + } \ + static void BM_Reduce_##O##_##R##_##C##_##S##_int64(int iters) { \ + BM_SegmentReduction(iters, #O, R, C, S); \ + } \ + BENCHMARK(BM_Reduce_##O##_##R##_##C##_##S##_int32); \ + BENCHMARK(BM_Reduce_##O##_##R##_##C##_##S##_int64); + +#define BM_Reduce_Arg(R, C, S) \ + BM_Reduce(SegmentSum, R, C, S); \ + BM_Reduce(SegmentMean, R, C, S); + +BM_Reduce_Arg(64, 32, 1); +BM_Reduce_Arg(4096, 128, 1); + +BM_Reduce_Arg(16, 8, 2); +BM_Reduce_Arg(64, 32, 2); +BM_Reduce_Arg(4096, 32, 2); +BM_Reduce_Arg(4096, 128, 2); + +static void SparseSegmentMeanGradHelper(int iters, float uniqueness, int size) { + testing::StopTiming(); + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + CHECK_LE(uniqueness, 1.0); + CHECK_GT(uniqueness, 0.0); + + const int kNumIndices = size; + Tensor indices(DT_INT32, TensorShape({kNumIndices})); + auto indices_flat = indices.flat(); + Tensor segments(DT_INT32, TensorShape({kNumIndices})); + auto segments_flat = segments.flat(); + + int kUniqueIndices = uniqueness * kNumIndices; + Tensor output_dim0(DT_INT32, TensorShape({})); + output_dim0.scalar()() = kUniqueIndices; + + for (int i = 0; i < kNumIndices; ++i) { + indices_flat(i) = (i * 31) % 
kUniqueIndices; + segments_flat(i) = i * .8; + } + + const int kDim1 = segments_flat(kNumIndices - 1) + 1; + const int kDim2 = 128; + Tensor input(DT_FLOAT, TensorShape({kDim1, kDim2})); + input.flat().setRandom(); + + Node* node; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "SparseSegmentMeanGrad") + .Input(test::graph::Constant(g, input)) + .Input(test::graph::Constant(g, indices)) + .Input(test::graph::Constant(g, segments)) + .Input(test::graph::Constant(g, output_dim0)) + .Attr("T", DT_FLOAT) + .Finalize(g, &node)); + + testing::UseRealTime(); + testing::BytesProcessed(static_cast(iters) * (kDim1 * kDim2) * + sizeof(float)); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} + +static void BM_SparseSegmentMeanGrad_Low(int iters, int size) { + return SparseSegmentMeanGradHelper(iters, 1.0, size); +} + +static void BM_SparseSegmentMeanGrad_High(int iters, int size) { + return SparseSegmentMeanGradHelper(iters, 0.01, size); +} + +BENCHMARK(BM_SparseSegmentMeanGrad_Low)->Arg(1000)->Arg(100000); +BENCHMARK(BM_SparseSegmentMeanGrad_High)->Arg(1000)->Arg(100000); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc new file mode 100644 index 0000000000..2abb183d1a --- /dev/null +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -0,0 +1,116 @@ +#include "tensorflow/core/kernels/sendrecv_ops.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +static string GetRendezvousKeyPrefix(const string& send_device, + const string& recv_device, + const uint64 send_device_incarnation, + const string& tensor_name) { + return strings::StrCat(send_device, ";", + strings::FpToString(send_device_incarnation), ";", + recv_device, ";", tensor_name); +} + +static string GetRendezvousKey(const string& key_prefix, + const FrameAndIter& frame_iter) { + return strings::StrCat(key_prefix, ";", frame_iter.frame_id, ":", + frame_iter.iter_id); +} + +SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string send_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); + string recv_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); + uint64 send_device_incarnation; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("send_device_incarnation", + reinterpret_cast(&send_device_incarnation))); + string tensor_name; + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device, + send_device_incarnation, tensor_name); +} + +void SendOp::Compute(OpKernelContext* ctx) { + OP_REQUIRES( + ctx, ctx->rendezvous() != nullptr, + errors::Internal("Op kernel context needs to provide a rendezvous.")); + const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter()); + VLOG(2) << "Send " << key; + + // The device context may be passed between the Send/Recv + // boundary, so that the device context used to produce the Tensor + // is used when performing the copy on the recv side (which may be + // a different device). 
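+  // The key produced above has the form
+  //   <send_device>;<incarnation via strings::FpToString>;<recv_device>;<tensor_name>;<frame_id>:<iter_id>
+  // so, as an illustrative example, a tensor named "t0" sent in frame 0,
+  // iteration 0 ends its key with ";t0;0:0".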
+ Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->input_alloc_attr(0); + Status s = + ctx->rendezvous()->Send(key, args, ctx->input(0), ctx->is_input_dead()); + ctx->SetStatus(s); +} + +REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp); +REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp); + +REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp); +REGISTER_KERNEL_BUILDER( + Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp); + +RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + string send_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); + string recv_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); + uint64 send_device_incarnation; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("send_device_incarnation", + reinterpret_cast(&send_device_incarnation))); + string tensor_name; + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device, + send_device_incarnation, tensor_name); +} + +void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { + OP_REQUIRES( + ctx, ctx->rendezvous() != nullptr, + errors::Internal("Op kernel context needs to provide a rendezvous.")); + const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter()); + VLOG(2) << "Recv " << key; + + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->output_alloc_attr(0); + ctx->rendezvous()->RecvAsync( + key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& val, bool is_dead) { + ctx->SetStatus(s); + if (s.ok()) { + // 'ctx' allocates the output tensor of the expected type. The + // runtime checks whether the tensor received here is the same type. + if (!is_dead) { + ctx->set_output(0, val); + } + *ctx->is_output_dead() = is_dead; + } + done(); + }); +} + +REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp); +REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp); + +REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp); +REGISTER_KERNEL_BUILDER( + Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/sendrecv_ops.h b/tensorflow/core/kernels/sendrecv_ops.h new file mode 100644 index 0000000000..b3f5703ccf --- /dev/null +++ b/tensorflow/core/kernels/sendrecv_ops.h @@ -0,0 +1,32 @@ +#ifndef TENSORFLOW_KERNELS_SENDRECV_OPS_H_ +#define TENSORFLOW_KERNELS_SENDRECV_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class SendOp : public OpKernel { + public: + explicit SendOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + + private: + string key_prefix_; + + TF_DISALLOW_COPY_AND_ASSIGN(SendOp); +}; + +class RecvOp : public AsyncOpKernel { + public: + explicit RecvOp(OpKernelConstruction* ctx); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + string key_prefix_; + + TF_DISALLOW_COPY_AND_ASSIGN(RecvOp); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_SENDRECV_OPS_H_ diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc new file mode 100644 index 0000000000..60ba2e15f9 --- /dev/null +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -0,0 +1,123 @@ +// See docs in ../ops/math_ops.cc. 
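+// The RangeOp below emits the arithmetic sequence [start, limit) with stride
+// delta, allocating size = (limit - start + delta - 1) / delta elements.
+// E.g. (illustrative) start=3, limit=10, delta=3 gives size = 9 / 3 = 3 and
+// output [3, 6, 9]; start <= limit and delta > 0 are required.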
+ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +int32 GetValue(int32 v) { return v; } + +template +class RangeOp : public OpKernel { + public: + explicit RangeOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& start_in = context->input(0); + const Tensor& limit_in = context->input(1); + const Tensor& delta_in = context->input(2); + OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(start_in.shape()), + errors::InvalidArgument("start must be a scalar, not shape ", + start_in.shape().ShortDebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(limit_in.shape()), + errors::InvalidArgument("limit must be a scalar, not shape ", + limit_in.shape().ShortDebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(delta_in.shape()), + errors::InvalidArgument("delta must be a scalar, not shape ", + delta_in.shape().ShortDebugString())); + const int32 start = GetValue(start_in.scalar()()); + const int32 limit = GetValue(limit_in.scalar()()); + OP_REQUIRES(context, start <= limit, + errors::InvalidArgument("Requires start <= limit: ", start, "/", + limit)); + const int32 delta = GetValue(delta_in.scalar()()); + OP_REQUIRES(context, delta > 0, + errors::InvalidArgument("Requires delta > 0: ", delta)); + int32 size = (limit - start + delta - 1) / delta; + Tensor* out = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({size}), &out)); + auto flat = out->flat(); + int32 val = start; + for (int32 i = 0; i < size; ++i) { + flat(i) = T(val); + val += delta; + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("Range") + .Device(DEVICE_CPU) + .HostMemory("start") + .HostMemory("limit") + .HostMemory("delta") + .HostMemory("output"), + RangeOp); + +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("Range") + .Device(DEVICE_GPU) + .HostMemory("start") + .HostMemory("limit") + .HostMemory("delta") + .HostMemory("output"), + RangeOp); +#endif // GOOGLE_CUDA + +template +class LinSpaceOp : public OpKernel { + public: + explicit LinSpaceOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& start_in = context->input(0); + const Tensor& stop_in = context->input(1); + const Tensor& num_in = context->input(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(start_in.shape()), + errors::InvalidArgument("start must be a scalar, not shape ", + start_in.shape().ShortDebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(stop_in.shape()), + errors::InvalidArgument("stop must be a scalar, not shape ", + stop_in.shape().ShortDebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_in.shape()), + errors::InvalidArgument("num must be a scalar, not shape ", + num_in.shape().ShortDebugString())); + const T start = start_in.scalar()(); + const T stop = stop_in.scalar()(); + const int32 num = num_in.scalar()(); + OP_REQUIRES(context, num > 0, + errors::InvalidArgument("Requires num > 0: ", num)); + Tensor* out = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num}), &out)); + auto flat = out->flat(); + if (num == 1) { + flat(0) = start; + } else { + const T step = (stop - start) / (num - 1); + for (int32 i = 0; i < num; ++i) flat(i) = start + step * i; + } + 
} +}; + +REGISTER_KERNEL_BUILDER(Name("LinSpace") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .HostMemory("start") + .HostMemory("stop") + .HostMemory("num") + .HostMemory("output"), + LinSpaceOp); +REGISTER_KERNEL_BUILDER(Name("LinSpace") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .HostMemory("start") + .HostMemory("stop") + .HostMemory("num") + .HostMemory("output"), + LinSpaceOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc new file mode 100644 index 0000000000..7cb1da8983 --- /dev/null +++ b/tensorflow/core/kernels/shape_ops.cc @@ -0,0 +1,261 @@ +// See docs in ../ops/array_ops.cc. + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class ShapeOp : public OpKernel { + public: + explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& inp = ctx->input(0); + const int rank = inp.dims(); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out)); + auto vec = out->vec(); + for (int i = 0; i < rank; ++i) vec(i) = inp.dim_size(i); + } + + bool IsExpensive() override { return false; } +}; +REGISTER_KERNEL_BUILDER(Name("Shape").Device(DEVICE_CPU).HostMemory("output"), + ShapeOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE_GPU) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ShapeOp) +TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Shape") + .Device(DEVICE_GPU) + .HostMemory("input") + .HostMemory("output") + .TypeConstraint("T"), + ShapeOp); + +class RankOp : public OpKernel { + public: + explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& inp = ctx->input(0); + const int rank = inp.dims(); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); + out->scalar()() = rank; + } + + bool IsExpensive() override { return false; } +}; +REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"), + RankOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Rank") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("output"), \ + RankOp); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Rank") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("input") + .HostMemory("output"), + RankOp); + +class SizeOp : public OpKernel { + public: + explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& inp = ctx->input(0); + const int64 size = inp.NumElements(); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); + // TODO(josh11b): switch output to int64? 
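+    // Note: a tensor with 2^32 elements (e.g. shape [65536, 65536]) would
+    // overflow this int32 output, hence the TODO above.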
+ out->scalar()() = size; + } + + bool IsExpensive() override { return false; } +}; +REGISTER_KERNEL_BUILDER(Name("Size").Device(DEVICE_CPU).HostMemory("output"), + SizeOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("output"), \ + SizeOp); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Size") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("input") + .HostMemory("output"), + SizeOp); + +class ExpandDimsOp : public OpKernel { + public: + explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + int dim = ctx->input(1).flat()(0); + OP_REQUIRES( + ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), + errors::InvalidArgument("Tried to expand dim index ", dim, + " for tensor with ", ctx->input(0).dims(), + " dimensions.")); + + auto existing_dims = ctx->input(0).shape().dim_sizes(); + std::vector new_shape(existing_dims.size()); + for (size_t i = 0; i < new_shape.size(); ++i) { + new_shape[i] = existing_dims[i]; + } + + // We emulate numpy's interpretation of the dim axis when + // -input.dims() >= dim <= input.dims(). + if (dim < 0) { + dim += existing_dims.size() + 1; + } + + // Clamp to the end if needed. + dim = std::min(dim, existing_dims.size()); + new_shape.emplace(new_shape.begin() + dim, 1); + const TensorShape output_shape(new_shape); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); + if (!output->CopyFrom(ctx->input(0), output_shape)) { + // This should never happen, since the sizes of the input and output + // should always be the same (we only expand the dimension with 1). + ctx->SetStatus( + errors::Internal("Could not expand dimension with input shape ", + ctx->input(0).shape().DebugString(), + " and output shape ", output_shape.DebugString())); + } + } +}; +REGISTER_KERNEL_BUILDER(Name("ExpandDims").Device(DEVICE_CPU).HostMemory("dim"), + ExpandDimsOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("ExpandDims") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("dim"), \ + ExpandDimsOp); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +REGISTER_KERNEL_BUILDER(Name("ExpandDims") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("input") + .HostMemory("dim") + .HostMemory("output"), + ExpandDimsOp); + +class SqueezeOp : public OpKernel { + public: + explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + std::vector squeeze_dims; + OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims)); + squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end()); + } + + void Compute(OpKernelContext* ctx) override { + auto existing_dims = ctx->input(0).shape().dim_sizes(); + std::vector new_shape; + + std::unordered_set wrapped_squeeze_dims; + wrapped_squeeze_dims.reserve(squeeze_dims_.size()); + // Validate squeeze dims against the input. 
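+    // E.g. (illustrative): for an input of shape [1, 2, 1, 3], an empty
+    // squeeze_dims removes every size-1 dimension, giving [2, 3];
+    // squeeze_dims = {-2} wraps to dimension 2 and gives [1, 2, 3]; asking to
+    // squeeze dimension 1 (size 2) is an InvalidArgument.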
+ for (int32 dim : squeeze_dims_) { + OP_REQUIRES( + ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()), + errors::InvalidArgument("Tried to squeeze dim index ", dim, + " for tensor with ", ctx->input(0).dims(), + " dimensions.")); + // If dim is < 0, we wrap around (-1 means the last element). + if (dim < 0) { + dim = existing_dims.size() + dim; + } + + wrapped_squeeze_dims.insert(dim); + } + + for (size_t i = 0; i < existing_dims.size(); ++i) { + auto existing_dim = existing_dims[i]; + + // If squeeze_set is non-empty, only squeeze those dimensions. + if (!wrapped_squeeze_dims.empty()) { + if (wrapped_squeeze_dims.count(i) > 0) { + OP_REQUIRES(ctx, existing_dim == 1, + errors::InvalidArgument("Tried to explicitly squeeze " + "dimension ", + i, " but dimension was not 1: ", + existing_dim)); + } else { + // This dimension is not being squeezed. + new_shape.push_back(existing_dim); + } + } else { + // Copy over all non-1-length dimensions. + if (existing_dim != 1) { + new_shape.push_back(existing_dim); + } + } + } + + const TensorShape output_shape(new_shape); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); + if (!output->CopyFrom(ctx->input(0), output_shape)) { + // This should never happen, since the sizes of the input and + // output should always be the same. + ctx->SetStatus(errors::Internal("Could not squeeze input with shape ", + ctx->input(0).shape().DebugString(), + " and output shape ", + output_shape.DebugString())); + } + } + + private: + std::unordered_set squeeze_dims_; +}; + +REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Squeeze").Device(DEVICE_GPU).TypeConstraint("T"), \ + SqueezeOp); +TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc new file mode 100644 index 0000000000..3477266d5d --- /dev/null +++ b/tensorflow/core/kernels/slice_op.cc @@ -0,0 +1,242 @@ +// See docs in ../ops/array_ops.cc. + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/core/kernels/slice_op.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace { + +gtl::InlinedVector IntTensorToInt64Vec(const Tensor& tensor) { + gtl::InlinedVector out; + if (tensor.dtype() == DT_INT32) { + for (int64 i = 0; i < tensor.NumElements(); ++i) { + out.push_back(tensor.flat()(i)); + } + } else if (tensor.dtype() == DT_INT64) { + for (int64 i = 0; i < tensor.NumElements(); ++i) { + out.push_back(tensor.flat()(i)); + } + } else { + LOG(FATAL) << "begin must be either int32 or int64"; + } + return out; +} + +} // namespace + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// Shared code that is not dependent on the type of T. We do this to reduce +// code size by not duplicating all this for all T (float, double, int32, etc.) 
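+// E.g. (illustrative): for an input of shape [4, 6] with begin = [1, 2] and
+// size = [2, -1], the -1 expands to 6 - 2 = 4, the output shape is [2, 4],
+// and both is_identity and slice_dim0 are false. is_identity is true only
+// when every dimension is taken whole; slice_dim0 is true when only
+// dimension 0 is actually sliced.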
+static void SharedValidation(OpKernelContext* context, + TensorShape* output_shape, bool* is_identity, + bool* slice_dim0, + gtl::InlinedVector* begin, + gtl::InlinedVector* size) { + const Tensor& input = context->input(0); + const Tensor& begin_tensor = context->input(1); + const Tensor& size_tensor = context->input(2); + + OP_REQUIRES( + context, TensorShapeUtils::IsLegacyVector(begin_tensor.shape()) && + TensorShapeUtils::IsLegacyVector(size_tensor.shape()) && + begin_tensor.NumElements() == input.dims() && + size_tensor.NumElements() == input.dims(), + errors::InvalidArgument( + "Expected begin and size arguments to be 1-D tensors of size ", + input.dims(), ", but got ", begin_tensor.NumElements(), " and ", + size_tensor.NumElements(), " instead.")); + + const int input_dims = input.dims(); + *begin = IntTensorToInt64Vec(begin_tensor); + *size = IntTensorToInt64Vec(size_tensor); + for (int i = 0; i < input_dims; ++i) { + if ((*size)[i] == -1) { + // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". + (*size)[i] = input.dim_size(i) - (*begin)[i]; + } + } + + *is_identity = true; + *slice_dim0 = true; + for (int i = 0; i < input_dims; ++i) { + int64 b = (*begin)[i]; + int64 s = (*size)[i]; + if (input.dim_size(i) == 0) { + OP_REQUIRES( + context, b == 0 && s == 0, + errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b, + ") and size[", i, "] == 0 ", "(got ", s, + ") when ", "input.dim_size(", i, ") == 0")); + } else { + OP_REQUIRES(context, 0 <= b && b <= input.dim_size(i), + errors::InvalidArgument("Expected begin[", i, "] in [0, ", + input.dim_size(i), "], but got ", b)); + OP_REQUIRES( + context, 0 <= s && b + s <= input.dim_size(i), + errors::InvalidArgument("Expected size[", i, "] in [0, ", + input.dim_size(i) - b, "], but ", "got ", s)); + } + output_shape->AddDim(s); + const bool take_all = (b == 0) && (s == input.dim_size(i)); + (*is_identity) &= take_all; + (*slice_dim0) &= (i == 0) || take_all; + } +} + +template +class SliceOp : public OpKernel { + public: + explicit SliceOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + TensorShape output_shape; + bool is_identity = true; + bool slice_dim0 = true; + gtl::InlinedVector begin; + gtl::InlinedVector size; + SharedValidation(context, &output_shape, &is_identity, &slice_dim0, &begin, + &size); + if (!context->status().ok()) return; + const Tensor& input = context->input(0); + if (is_identity) { + VLOG(1) << "Slice identity"; + context->set_output(0, input); + return; + } + + if (slice_dim0 && IsInnerDimsSizeAligned(input.shape())) { + VLOG(1) << "Slice dim 0: " << input.shape().DebugString(); + CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true. + context->set_output(0, input.Slice(begin[0], begin[0] + size[0])); + return; + } + + Tensor* result = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result)); + const int input_dims = input.dims(); + + if (output_shape.num_elements() > 0) { + if (std::is_same::value && input_dims == 2 && + DataTypeCanUseMemcpy(DataTypeToEnum::v())) { + auto input = context->input(0).tensor(); + auto output = result->tensor(); + // TODO(agarwal): Consider multi-threading this loop for cases where + // size[0] is very large. 
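+        // The loop below copies one contiguous row per iteration: size[1]
+        // elements starting at column begin[1] of input row begin[0] + i are
+        // memcpy'd into output row i, while the next input and output rows
+        // are prefetched to hide memory latency.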
+ for (int i = 0; i < size[0]; ++i) { + const int row = begin[0] + i; + if (i + 1 < size[0]) { + port::prefetch(&output(i + 1, 0)); + port::prefetch(&input(row + 1, begin[1])); + } + memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T)); + } + return; + } +#define HANDLE_DIM(NDIM) \ + if (input_dims == NDIM) { \ + HandleCase(context, begin, size, result); \ + return; \ + } + + HANDLE_DIM(1); + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + +#undef HANDLE_DIM + + OP_REQUIRES(context, false, errors::Unimplemented( + "SliceOp : Unhandled input dimensions")); + } + } + + private: + template + void HandleCase(OpKernelContext* context, const gtl::ArraySlice& begin, + const gtl::ArraySlice& size, Tensor* result) { + Eigen::DSizes indices; + Eigen::DSizes sizes; + for (int i = 0; i < NDIM; ++i) { + indices[i] = begin[i]; + sizes[i] = size[i]; + } + + functor::Slice()( + context->eigen_device(), result->tensor(), + context->input(0).tensor(), indices, sizes); + } +}; + +#define REGISTER_SLICE(type) \ + REGISTER_KERNEL_BUILDER(Name("Slice") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("size"), \ + SliceOp) + +TF_CALL_ALL_TYPES(REGISTER_SLICE); +REGISTER_SLICE(bfloat16); + +#undef REGISTER_SLICE + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T, NDIM) \ + template <> \ + void Slice::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + const Eigen::DSizes& indices, \ + const Eigen::DSizes& sizes); \ + extern template struct Slice; + +#define DECLARE_FOR_N(T) \ + DECLARE_GPU_SPEC(T, 1); \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N); +DECLARE_FOR_N(int32); + +#undef DECLARE_FOR_N +#undef DECLARE_GPU_SPEC +} // namespace functor + +#define REGISTER_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("Slice") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("size") \ + .TypeConstraint("Index"), \ + SliceOp) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +REGISTER_GPU(int32); + +#undef REGISTER_GPU + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h new file mode 100644 index 0000000000..1b6bd9c112 --- /dev/null +++ b/tensorflow/core/kernels/slice_op.h @@ -0,0 +1,25 @@ +#ifndef TENSORFLOW_KERNELS_SLICE_OP_H_ +#define TENSORFLOW_KERNELS_SLICE_OP_H_ + +// Functor definition for SliceOp, must be compilable by nvcc. 
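+//
+// Illustrative use (hypothetical, for exposition only): for a rank-2 float
+// tensor, functor::Slice<CPUDevice, float, 2>()(d, out, in, {1, 0}, {2, 3})
+// writes in(1..2, 0..2) into out, i.e. out = in.slice({1, 0}, {2, 3})
+// evaluated on device d.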
+ +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +template +struct Slice { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes) { + output.device(d) = input.slice(slice_indices, slice_sizes); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_SLICE_OP_H_ diff --git a/tensorflow/core/kernels/slice_op_gpu.cu.cc b/tensorflow/core/kernels/slice_op_gpu.cu.cc new file mode 100644 index 0000000000..6e919b244c --- /dev/null +++ b/tensorflow/core/kernels/slice_op_gpu.cu.cc @@ -0,0 +1,31 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include + +#include "tensorflow/core/kernels/slice_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Slice; \ + template struct functor::Slice; \ + template struct functor::Slice; \ + template struct functor::Slice; \ + template struct functor::Slice; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); +DEFINE_GPU_KERNELS(int32); + +#undef DEFINE_GPU_KERNELS + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/slice_op_test.cc b/tensorflow/core/kernels/slice_op_test.cc new file mode 100644 index 0000000000..27c78c6dc0 --- /dev/null +++ b/tensorflow/core/kernels/slice_op_test.cc @@ -0,0 +1,73 @@ +#include +#include +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +// For the benchmark, we set up two 2-dimensional tensors, each kDim1 x 'dim' +// in size, and concat them together along "concat_dimension" +template +static void SliceHelper(int iters, int size) { + testing::StopTiming(); + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + DataType dt = DataTypeToEnum::v(); + int kDim = 100; + int kMaxSize = 15000; + CHECK_LT(size, kMaxSize); + + Tensor begin(DT_INT32, TensorShape({2})); + begin.flat()(0) = 10; + begin.flat()(1) = 10; + + Tensor sizes(DT_INT32, TensorShape({2})); + sizes.flat()(0) = kDim; + sizes.flat()(1) = size; + + Tensor input(dt, TensorShape({2 * kDim, kMaxSize})); + input.flat().setRandom(); + + Node* node; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Slice") + .Input(test::graph::Constant(g, input)) + .Input(test::graph::Constant(g, begin)) + .Input(test::graph::Constant(g, sizes)) + .Attr("T", dt) + .Finalize(g, &node)); + + testing::BytesProcessed(static_cast(iters) * kDim * size * sizeof(T)); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); + testing::UseRealTime(); +} + +static void BM_SliceFloat(int 
iters, int dim2) { + SliceHelper(iters, dim2); +} + +BENCHMARK(BM_SliceFloat)->Arg(100)->Arg(1000)->Arg(10000); + +static void BM_SliceBFloat16(int iters, int dim2) { + SliceHelper(iters, dim2); +} + +BENCHMARK(BM_SliceBFloat16)->Arg(100)->Arg(1000)->Arg(10000); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc new file mode 100644 index 0000000000..abe6331a4f --- /dev/null +++ b/tensorflow/core/kernels/softmax_op.cc @@ -0,0 +1,62 @@ +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/kernels/softmax_op.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class SoftmaxOp : public OpKernel { + public: + explicit SoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& logits_in = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()), + errors::InvalidArgument("logits must be 2-dimensional")); + Tensor* softmax_out = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, logits_in.shape(), &softmax_out)); + functor::SoftmaxFunctor functor; + functor(context->eigen_device(), logits_in.matrix(), + softmax_out->matrix()); + } +}; + +// Partial specialization for a CPUDevice, that uses the Eigen implementation +// from SoftmaxEigenImpl. +namespace functor { +template +struct SoftmaxFunctor { + void operator()(const CPUDevice& d, typename TTypes::ConstMatrix logits, + typename TTypes::Matrix softmax) { + SoftmaxEigenImpl::Compute(d, logits, softmax); + } +}; +} // namespace functor + +REGISTER_KERNEL_BUILDER(Name("Softmax") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + SoftmaxOp); +REGISTER_KERNEL_BUILDER(Name("Softmax") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + SoftmaxOp); + +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("Softmax") + .Device(DEVICE_GPU) + .TypeConstraint("T"), + SoftmaxOp); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/softmax_op.h b/tensorflow/core/kernels/softmax_op.h new file mode 100644 index 0000000000..69bd531b70 --- /dev/null +++ b/tensorflow/core/kernels/softmax_op.h @@ -0,0 +1,70 @@ +#ifndef TENSORFLOW_KERNELS_SOFTMAX_OP_H_ +#define TENSORFLOW_KERNELS_SOFTMAX_OP_H_ +// Functor definition for SoftmaxOp, must be compilable by nvcc. + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by SoftmaxOp to do the computations. +template +struct SoftmaxFunctor { + // Computes Softmax activation. + // + // logits: dim: batch_size, num_classes. + // softmax: dims: batch_size, num_classes. + void operator()(const Device& d, typename TTypes::ConstMatrix logits, + typename TTypes::Matrix softmax); +}; + +// Eigen code implementing SoftmaxFunctor::operator(). +// This code works for both CPU and GPU and is used by the functor +// specializations for both device types. 
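+//
+// Mathematically, softmax(x)_ij = exp(x_ij - max_j x_ij) /
+// sum_j exp(x_ij - max_j x_ij); subtracting the row-wise max keeps exp() from
+// overflowing without changing the result. E.g. (illustrative) a logits row
+// [1, 2, 3] becomes exp([-2, -1, 0]) = [0.135, 0.368, 1.0], which normalizes
+// to roughly [0.090, 0.245, 0.665].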
+template +struct SoftmaxEigenImpl { + static void Compute(const Device& d, typename TTypes::ConstMatrix logits, + typename TTypes::Matrix softmax) { + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + +// These arrays are used to reduce along the class dimension, and broadcast +// the resulting value to all classes. +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); +#else + Eigen::IndexList > along_class; + Eigen::IndexList > depth_dim; + Eigen::IndexList > batch_by_one; + batch_by_one.set(0, batch_size); + Eigen::IndexList, int> one_by_class; + one_by_class.set(1, num_classes); +#endif + // NOTE(mdevin): If you modify this implementation please run + // the ImageNetSoftmaxFwd benchmark in core_ops_test.cc. + // + // softmax = exp(logits - max(logits along classes)); + softmax.device(d) = (logits - + logits.maximum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)).exp(); + // softmax = softmax / sum(softmax along classes); + softmax.device(d) = (softmax / + softmax.sum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_SOFTMAX_OP_H_ diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc new file mode 100644 index 0000000000..d5aaf9c364 --- /dev/null +++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc @@ -0,0 +1,31 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/softmax_op.h" + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Partial specialization for a GPUDevice, that uses the Eigen implementation +// from SoftmaxEigenImpl. +namespace functor { +template +struct SoftmaxFunctor { + void operator()(const GPUDevice& d, typename TTypes::ConstMatrix logits, + typename TTypes::Matrix softmax) { + SoftmaxEigenImpl::Compute(d, logits, softmax); + } +}; +} // end namespace functor + +// Instantiate the GPU implementation for float. +template struct functor::SoftmaxFunctor; + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc new file mode 100644 index 0000000000..b5fb57d3c5 --- /dev/null +++ b/tensorflow/core/kernels/softplus_op.cc @@ -0,0 +1,97 @@ +// See docs in ../ops/nn_ops.cc. 
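+//
+// Softplus(x) = log(1 + exp(x)) and its gradient is the logistic sigmoid
+// 1 / (1 + exp(-x)); see the functors in softplus_op.h. Above a cutoff the
+// functor returns x directly, since softplus(x) -> x for large x and this
+// avoids overflow in exp(x).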
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/softplus_op.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class SoftplusOp : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Softplus functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +template +class SoftplusGradOp + : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): inputs that were passed to SoftplusOp() + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OP_REQUIRES(context, a.IsSameSize(g), + errors::InvalidArgument("g and a must be the same size")); + functor::SoftplusGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Softplus").Device(DEVICE_CPU).TypeConstraint("T"), \ + SoftplusOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SoftplusGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + SoftplusGradOp); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Softplus::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Softplus; \ + \ + template <> \ + void SoftplusGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + extern template struct SoftplusGrad; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Softplus").Device(DEVICE_GPU).TypeConstraint("T"), \ + SoftplusOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SoftplusGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + SoftplusGradOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/softplus_op.h b/tensorflow/core/kernels/softplus_op.h new file mode 100644 index 0000000000..3545a78246 --- /dev/null +++ b/tensorflow/core/kernels/softplus_op.h @@ -0,0 +1,46 @@ +#ifndef TENSORFLOW_KERNELS_SOFTPLUS_OP_H_ +#define TENSORFLOW_KERNELS_SOFTPLUS_OP_H_ +// Functor definition for SoftplusOp and SoftplusGradOp, must be compilable by +// nvcc. + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by SoftplusOp to do the computations. +template +struct Softplus { + // Computes Softplus activation. + // + // features: any shape. 
+ // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + activations.device(d) = + (features > features.constant(30.f)) + .select(features, (features.exp() + features.constant(1.0f)).log()); + } +}; + +// Functor used by SoftplusGradOp to do the computations. +template +struct SoftplusGrad { + // Computes SoftplusGrad backprops. + // + // gradients: gradients backpropagated to the Softplus op. + // features: inputs that where passed to the Softplus op. + // backprops: gradients to backpropagate to the Softplus inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + backprops.device(d) = + gradients / ((-features).exp() + features.constant(1.0f)); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_SOFTPLUS_OP_H_ diff --git a/tensorflow/core/kernels/softplus_op_gpu.cu.cc b/tensorflow/core/kernels/softplus_op_gpu.cu.cc new file mode 100644 index 0000000000..7a974321a7 --- /dev/null +++ b/tensorflow/core/kernels/softplus_op_gpu.cu.cc @@ -0,0 +1,25 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include + +#include "tensorflow/core/kernels/softplus_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Definition of the GPU implementations declared in softplus_op.cc. +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Softplus; \ + template struct functor::SoftplusGrad; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/sparse_concat_op.cc b/tensorflow/core/kernels/sparse_concat_op.cc new file mode 100644 index 0000000000..72c267a47d --- /dev/null +++ b/tensorflow/core/kernels/sparse_concat_op.cc @@ -0,0 +1,139 @@ +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { + +template +class SparseConcatOp : public OpKernel { + public: + explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_)); + } + + void Compute(OpKernelContext* context) override { + OpInputList inds; + OP_REQUIRES_OK(context, context->input_list("indices", &inds)); + const int N = inds.size(); + for (int i = 0; i < N; i++) { + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()), + errors::InvalidArgument( + "Input indices should be a matrix but received shape ", + inds[i].shape().DebugString(), " at position ", i)); + } + + OpInputList vals; + OP_REQUIRES_OK(context, context->input_list("values", &vals)); + OP_REQUIRES(context, vals.size() == N, + errors::InvalidArgument("Expected ", N, " input values, got ", + vals.size())); + for (int i = 0; i < N; i++) { + OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()), + errors::InvalidArgument( + "Input values should be a vector but received shape ", + vals[i].shape().DebugString(), " at position ", i)); + } + 
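+    // For example (illustrative): concatenating sparse tensors with dense
+    // shapes [2, 3] and [2, 4] along concat_dim = 1 yields dense shape
+    // [2, 7]; every non-concat dimension must match, which the checks below
+    // enforce.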
+ OpInputList shapes; + OP_REQUIRES_OK(context, context->input_list("shapes", &shapes)); + OP_REQUIRES(context, shapes.size() == N, + errors::InvalidArgument("Expected ", N, " input shapes, got ", + shapes.size())); + for (int i = 0; i < N; i++) { + OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()), + errors::InvalidArgument( + "Input shapes should be a vector but received shape ", + shapes[i].shape().DebugString(), " at position ", i)); + } + + const TensorShape input_shape(shapes[0].vec()); + OP_REQUIRES( + context, concat_dim_ >= 0 && concat_dim_ < input_shape.dims(), + errors::InvalidArgument("Concat dimension must be between 0 and rank (", + input_shape.dims(), "), got ", concat_dim_)); + for (int i = 1; i < N; ++i) { + const TensorShape current_shape(shapes[i].vec()); + OP_REQUIRES(context, current_shape.dims() == input_shape.dims(), + errors::InvalidArgument( + "Ranks of all input tensors must match: expected ", + input_shape.dims(), " but got ", current_shape.dims(), + " at position ", i)); + for (int j = 0; j < input_shape.dims(); ++j) { + if (j != concat_dim_) { + OP_REQUIRES( + context, input_shape.dim_size(j) == current_shape.dim_size(j), + errors::InvalidArgument( + "Input shapes must match: expected ", input_shape.dim_size(j), + " for dimension ", j, " but got ", current_shape.dim_size(j), + " at position ", i)); + } + } + } + + // The input and output sparse tensors are assumed to be ordered along + // increasing dimension number. But in order for concat to work properly, + // order[0] must be concat_dim. So we will reorder the inputs to the + // concat ordering, concatenate, then reorder back to the standard order. + // We make a deep copy of the input tensors to ensure that the in-place + // reorder doesn't create race conditions for other ops that may be + // concurrently reading the indices and values tensors. + + gtl::InlinedVector std_order(input_shape.dims()); + std::iota(std_order.begin(), std_order.end(), 0); + + std::vector concat_order; + concat_order.reserve(input_shape.dims()); + concat_order.push_back(concat_dim_); + for (int j = 0; j < input_shape.dims(); ++j) { + if (j != concat_dim_) { + concat_order.push_back(j); + } + } + + std::vector sp_inputs; + for (int i = 0; i < N; ++i) { + const TensorShape current_shape(shapes[i].vec()); + sp_inputs.emplace_back(tensor::DeepCopy(inds[i]), + tensor::DeepCopy(vals[i]), current_shape, + std_order); + sp_inputs[i].Reorder(concat_order); + } + + sparse::SparseTensor concat = sparse::SparseTensor::Concat(sp_inputs); + concat.Reorder(std_order); + + context->set_output(0, concat.indices()); + context->set_output(1, concat.values()); + + Tensor* output_shape_out = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 2, TensorShape({concat.shape().dims()}), + &output_shape_out)); + auto output_shape = output_shape_out->vec(); + for (int j = 0; j < concat.shape().dims(); ++j) { + output_shape(j) = concat.shape().dim_size(j); + } + } + + private: + int concat_dim_; +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint("T"), \ + SparseConcatOp) + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS +} // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc new file mode 100644 index 0000000000..919e129ff8 --- /dev/null +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -0,0 +1,192 @@ +// See docs in ../ops/math_ops.cc. 
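To make the reordering step in SparseConcatOp above concrete: before concatenation each input's indices are re-sorted with concat_dim as the most significant dimension, and the concatenated result is sorted back to row-major order. Below is a minimal standalone sketch of that permutation and the lexicographic sort it implies; it is illustrative only and does not use the SparseTensor class.

#include <algorithm>
#include <array>
#include <iostream>
#include <vector>

int main() {
  const int rank = 3;
  const int concat_dim = 1;

  // Permutation with concat_dim first and the remaining dimensions in their
  // original order: {1, 0, 2} for rank 3 and concat_dim 1.
  std::vector<int> concat_order;
  concat_order.push_back(concat_dim);
  for (int j = 0; j < rank; ++j) {
    if (j != concat_dim) concat_order.push_back(j);
  }

  // Sorting index rows lexicographically under that permutation is what
  // Reorder(concat_order) does (the kernel also permutes the matching values).
  std::vector<std::array<int, 3>> indices = {{0, 2, 1}, {1, 0, 0}, {0, 1, 1}};
  std::sort(indices.begin(), indices.end(),
            [&](const std::array<int, 3>& a, const std::array<int, 3>& b) {
              for (int d : concat_order) {
                if (a[d] != b[d]) return a[d] < b[d];
              }
              return false;
            });

  for (const auto& row : indices) {
    std::cout << row[0] << " " << row[1] << " " << row[2] << "\n";
  }
  return 0;
}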
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/port.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +void PrefetchBlockNTA(const T& tensor, int si, int ei, int sj, int ej) { + for (int i = si; i < ei; ++i) { + for (int j = sj; j < ej; j = j + 16) { + port::prefetch(&tensor(i, j)); + } + } +} + +template +void PrefetchBlockT1(const T& tensor, int si, int ei, int sj, int ej) { + for (int i = si; i < ei; ++i) { + for (int j = sj; j < ej; j = j + 16) { + port::prefetch(&tensor(i, j)); + } + } +} + +struct Block { + Block(int sm, int em, int sk, int ek, int sn, int en) + : startm(sm), endm(em), startk(sk), endk(ek), startn(sn), endn(en) {} + + int startm; + int endm; + int startk; + int endk; + int startn; + int endn; +}; + +bool NextBlock(const int Bm, const int Bk, const int Bn, const int m_start, + const int m, const int k, const int n, const Block& b, + Block* next) { + *next = b; + if (b.endk < k) { + next->startk = b.endk; + next->endk = std::min(b.endk + Bk, k); + } else { + next->startk = 0; + next->endk = std::min(Bk, k); + if (b.endm < m) { + next->startm = b.endm; + next->endm = std::min(b.endm + Bm, m); + } else { + next->startm = m_start; + next->endm = std::min(m_start + Bm, m); + next->startn = b.endn; + next->endn = std::min(b.endn + Bn, n); + } + } + return next->startn == next->endn; +} + +class SparseMatMulOp : public OpKernel { + public: + explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& a = ctx->input(0); + const Tensor& b = ctx->input(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()), + errors::InvalidArgument("a is not a matrix")); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()), + errors::InvalidArgument("b is not a matrix")); + + auto left = a.matrix(); + auto right_mat = b.matrix(); + const int m = transpose_a_ ? left.dimension(1) : left.dimension(0); + const int k = transpose_a_ ? left.dimension(0) : left.dimension(1); + const int n = + transpose_b_ ? right_mat.dimension(0) : right_mat.dimension(1); + const int k2 = + transpose_b_ ? right_mat.dimension(1) : right_mat.dimension(0); + + OP_REQUIRES(ctx, k == k2, + errors::InvalidArgument("Matrix size incompatible: a: ", + a.shape().DebugString(), ", b: ", + b.shape().DebugString())); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output)); + auto out = output->matrix(); + + if (!a_is_sparse_) { + // Fallback to Eigen contract. + // Note that we currently don't optimize the case where only right is + // sparse. That can generally be handled by tranposing the order of the + // matmul. + Eigen::array, 1> dim_pair; + dim_pair[0].first = transpose_a_ ? 0 : 1; + dim_pair[0].second = transpose_b_ ? 
1 : 0; + out.device(ctx->template eigen_device()) = + left.contract(right_mat, dim_pair); + return; + } + typedef Eigen::Tensor Matrix; + std::unique_ptr right_tr_mat; + std::unique_ptr::ConstMatrix> right_tr_map; + if (transpose_b_) { + right_tr_mat.reset(new Matrix(k, n)); + Eigen::array perm({1, 0}); + right_tr_mat->device(ctx->template eigen_device()) = + right_mat.shuffle(perm); + right_tr_map.reset(new TTypes::ConstMatrix( + right_tr_mat->data(), right_tr_mat->dimensions())); + } + TTypes::ConstMatrix& right = + transpose_b_ ? *right_tr_map : right_mat; + + const bool transpose_a = transpose_a_; + + typedef Eigen::TensorMap, + Eigen::Unaligned> TensorMap; + typedef Eigen::TensorMap, + Eigen::Unaligned> ConstTensorMap; + typedef Eigen::DSizes DSizes; + const int Bm = 16; + const int Bk = 16; + const int Bn = 1024; + + auto work_shard = [m, n, k, transpose_a, Bm, Bk, Bn, &left, &right, &out]( + int64 start64, int64 end64) { + const int start = static_cast(start64); + const int end = static_cast(end64); + Block curr(start, std::min(start + Bm, end), 0, std::min(Bk, k), 0, + std::min(Bn, n)); + Block next(curr); + bool done = false; + for (int i = start; i < end; ++i) { + out.chip<0>(i).setZero(); + } + while (true) { + done = NextBlock(Bm, Bk, Bn, start, end, k, n, curr, &next); + + PrefetchBlockT1(right, curr.startk, curr.endk, curr.startn, curr.endn); + + // Process current block + for (int i = curr.startm; i < curr.endm; ++i) { + PrefetchBlockNTA(left, i, i + 1, curr.startk, curr.endk); + PrefetchBlockNTA(out, i, i + 1, curr.startn, curr.endn); + DSizes out_slice_shape(curr.endn - curr.startn); + TensorMap out_i(&out(i, curr.startn), out_slice_shape); + for (int j = curr.startk; j < curr.endk; ++j) { + const float l = transpose_a ? left(j, i) : left(i, j); + if (l == 0) continue; + ConstTensorMap right_j(&right(j, curr.startn), out_slice_shape); + out_i += right_j * l; + } + } + if (done) break; + curr = next; + } + }; + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + Shard(worker_threads.num_threads, worker_threads.workers, m, 2 * k * n, + work_shard); + } + + private: + bool transpose_a_; + bool transpose_b_; + bool a_is_sparse_; + bool b_is_sparse_; + TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp); +}; + +REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU), + SparseMatMulOp); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc new file mode 100644 index 0000000000..883d0d1224 --- /dev/null +++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc @@ -0,0 +1,139 @@ +#include "tensorflow/core/framework/types.pb.h" +#include +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +random::PhiloxRandom philox(1, 1); +random::SimplePhilox rnd(&philox); + +void Sparsify(Tensor* t, float sparsity) { + const int64 N = t->NumElements(); + CHECK_LE(sparsity, 1); + if (sparsity <= 0) return; + auto flat = t->flat(); + static const uint32 K = 10000; + for (int64 i = 0; i < N; ++i) { + if (rnd.Uniform(K) < sparsity * K) { + flat(i) = 0; + } + } +} + +Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a, + bool transpose_b, bool a_sparse, bool b_sparse) { + 
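+ // Adds a SparseMatMul node to *g that multiplies in0 by in1 with the given
+ // transpose and sparsity attributes, and returns the newly created node.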
Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "SparseMatMul") + .Input(in0) + .Input(in1) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b) + .Attr("a_is_sparse", a_sparse) + .Attr("b_is_sparse", b_sparse) + .Finalize(g, &ret)); + return ret; +} + +static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, float sparsity, + bool transpose_a, bool transpose_b, + bool a_sparse, bool b_sparse) { + a_sparse = a_sparse && (sparsity > 0); + b_sparse = b_sparse && (sparsity > 0); + + auto left_shape = transpose_a ? TensorShape({d, m}) : TensorShape({m, d}); + Tensor left(DataTypeToEnum::value, left_shape); + left.flat().setRandom(); + if (a_sparse) { + Sparsify(&left, sparsity); + } + + auto right_shape = transpose_b ? TensorShape({n, d}) : TensorShape({d, n}); + Tensor right(DataTypeToEnum::value, right_shape); + right.flat().setRandom(); + if (b_sparse) { + Sparsify(&right, sparsity); + } + + SparseMatMulNode(g, test::graph::Constant(g, left), + test::graph::Constant(g, right), transpose_a, transpose_b, + a_sparse, b_sparse); + return g; +} + +static Graph* SparseMatMul(int m, int n, int d, float sparsity, + bool transpose_a, bool transpose_b) { + Graph* g = new Graph(OpRegistry::Global()); + return SparseMatMulHelper(g, m, n, d, sparsity, transpose_a, transpose_b, + true, false); +} + +static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_a, + float sparsity_b) { + Graph* g = new Graph(OpRegistry::Global()); + if (sparsity_a == 0 && sparsity_b > 0) { + SparseMatMulHelper(g, m, n, d, sparsity_a, false, false, false, false); + SparseMatMulHelper(g, n, d, m, sparsity_b, true, true, true, false); + SparseMatMulHelper(g, m, d, n, sparsity_b, false, false, true, false); + } else { + SparseMatMulHelper(g, m, n, d, sparsity_a, false, true, true, false); + SparseMatMulHelper(g, d, n, m, sparsity_a, true, false, true, true); + SparseMatMulHelper(g, m, d, n, sparsity_b, false, false, true, false); + } + return g; +} + +#define BM_SPARSE(M, K, N, S) \ + static void BM_Sparse##_##M##_##K##_##N##_##S(int iters) { \ + testing::ItemsProcessed(static_cast(iters) * M * K * N * 2); \ + std::string label = strings::Printf("%d_%d_%d_%0.2f", M, K, N, S / 100.0); \ + testing::SetLabel(label); \ + test::Benchmark("cpu", SparseMatMul(M, N, K, S / 100.0, false, false)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S); + +BM_SPARSE(2048, 2048, 2048, 0); +BM_SPARSE(2048, 2048, 2048, 1); +BM_SPARSE(2048, 2048, 2048, 85); + +BM_SPARSE(1024, 1024, 1024, 0); +BM_SPARSE(1024, 1024, 1024, 1); +BM_SPARSE(1024, 1024, 1024, 85); + +BM_SPARSE(256, 256, 256, 1); +BM_SPARSE(512, 512, 512, 1); + +#define BM_SPARSE_MULTI(M, K, N, S1, S2) \ + static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2(int iters) { \ + testing::ItemsProcessed(static_cast(iters) * M * K * N * 2 * 3); \ + std::string label = strings::Printf("%d_%d_%d_%0.2f_%0.2f", M, K, N, \ + S1 / 100.0, S2 / 100.0); \ + testing::SetLabel(label); \ + test::Benchmark("cpu", MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2); + +BM_SPARSE_MULTI(512, 2140, 4096, 0, 82); +BM_SPARSE_MULTI(512, 4096, 2048, 83, 83); + +#define BM_SPARSE_TR(M, K, N, S, TA, TB) \ + static void BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB(int iters) { \ + testing::ItemsProcessed(static_cast(iters) * M * K * N * 2); \ + std::string label = \ + strings::Printf("%d_%d_%d_%d_%d_%0.2f", M, K, N, TA, TB, S / 100.0); \ + 
testing::SetLabel(label); \ + test::Benchmark("cpu", SparseMatMul(M, N, K, S / 100.0, TA, TB)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB); + +BM_SPARSE_TR(2048, 2048, 2048, 1, true, false); +BM_SPARSE_TR(2048, 2048, 2048, 1, false, true); +BM_SPARSE_TR(2048, 2048, 2048, 1, true, true); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_reorder_op.cc b/tensorflow/core/kernels/sparse_reorder_op.cc new file mode 100644 index 0000000000..fd6824a4e2 --- /dev/null +++ b/tensorflow/core/kernels/sparse_reorder_op.cc @@ -0,0 +1,71 @@ +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { + +template +class SparseReorderOp : public OpKernel { + public: + explicit SparseReorderOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input_ind = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_ind.shape()), + errors::InvalidArgument( + "Input indices should be a matrix but received shape", + input_ind.shape().DebugString())); + + const Tensor& input_val = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_val.shape()), + errors::InvalidArgument( + "Input values should be a vector but received shape", + input_val.shape().DebugString())); + + const Tensor& input_shape_in = context->input(2); + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()), + errors::InvalidArgument( + "Input shape should be a vector but received shape", + input_shape_in.shape().DebugString())); + + const TensorShape input_shape(input_shape_in.vec()); + + gtl::InlinedVector std_order(input_shape.dims()); + std::iota(std_order.begin(), std_order.end(), 0); + + // Check if the sparse tensor is already ordered correctly + sparse::SparseTensor input_sp(input_ind, input_val, input_shape, std_order); + + if (input_sp.IndicesValid()) { + context->set_output(0, input_sp.indices()); + context->set_output(1, input_sp.values()); + } else { + // Deep-copy the input Tensors, then reorder in-place + sparse::SparseTensor reordered_sp(tensor::DeepCopy(input_ind), + tensor::DeepCopy(input_val), + input_shape); + reordered_sp.Reorder(std_order); + context->set_output(0, reordered_sp.indices()); + context->set_output(1, reordered_sp.values()); + } + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseReorder").Device(DEVICE_CPU).TypeConstraint("T"), \ + SparseReorderOp) + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS +} // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc new file mode 100644 index 0000000000..47e91c134d --- /dev/null +++ b/tensorflow/core/kernels/sparse_to_dense_op.cc @@ -0,0 +1,129 @@ +// See core/ops/sparse_ops.cc for documentation. +// +// NOTE: the operations in this file only are suitable for execution +// on CPUs. 
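The SparseReorderOp above only deep-copies and sorts when the incoming indices are not already in the standard row-major (lexicographic) order. The standalone sketch below mimics that check-then-sort on (index row, value) pairs with illustrative data; it stands in for the ordering aspect of IndicesValid() and Reorder(), not for the real SparseTensor implementation.

#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  using IndexRow = std::array<int64_t, 2>;
  // (row, col) indices paired with their values, deliberately out of
  // row-major order.
  std::vector<std::pair<IndexRow, float>> entries = {
      {{2, 0}, 3.f}, {{0, 1}, 1.f}, {{1, 3}, 2.f}};

  auto lex_less = [](const std::pair<IndexRow, float>& a,
                     const std::pair<IndexRow, float>& b) {
    return a.first < b.first;  // std::array compares lexicographically.
  };

  // Stand-in for the "already valid" fast path: sort only if needed.
  if (!std::is_sorted(entries.begin(), entries.end(), lex_less)) {
    std::sort(entries.begin(), entries.end(), lex_less);
  }

  for (const auto& e : entries) {
    std::cout << "(" << e.first[0] << ", " << e.first[1] << ") -> " << e.second
              << "\n";
  }
  return 0;
}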
+ +#define EIGEN_USE_THREADS + +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { + +// Operator to convert sparse representations to dense. +template +class SparseToDense : public OpKernel { + public: + explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* c) override { + // sparse_indices + const Tensor& indices = c->input(0); + OP_REQUIRES(c, indices.dims() <= 2, + errors::InvalidArgument( + "sparse_indices should be a scalar, vector, or matrix, " + "got shape ", + indices.shape().ShortDebugString())); + const int64 num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1; + const int64 num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1; + + // output_shape + const Tensor& output_shape = c->input(1); + OP_REQUIRES( + c, TensorShapeUtils::IsLegacyVector(output_shape.shape()), + errors::InvalidArgument("output_shape should be a vector, got shape ", + output_shape.shape().ShortDebugString())); + OP_REQUIRES(c, output_shape.NumElements() == num_dims, + errors::InvalidArgument( + "output_shape has incorrect number of elements: ", + output_shape.NumElements(), " should be: ", num_dims)); + + // sparse_values + const Tensor& sparse_values = c->input(2); + const int64 num_values = sparse_values.NumElements(); + OP_REQUIRES( + c, sparse_values.dims() == 0 || + (sparse_values.dims() == 1 && num_values == num_elems), + errors::InvalidArgument("sparse_values has incorrect shape ", + sparse_values.shape().ShortDebugString(), + ", should be [] or [", num_elems, "]")); + + // default_value + const Tensor& default_value = c->input(3); + OP_REQUIRES(c, TensorShapeUtils::IsScalar(default_value.shape()), + errors::InvalidArgument("default_value should be a scalar.")); + + auto output_shape_vec = output_shape.flat(); + Tensor* output = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShapeUtils::MakeShape( + output_shape_vec.data(), + output_shape_vec.size()), + &output)); + + TensorShape ix_shape({num_elems, num_dims}); + Tensor indices_shaped(DT_INT64, ix_shape); + if (indices.dtype() == DT_INT64) { + CHECK(indices_shaped.CopyFrom(indices, ix_shape)); + } else { + indices_shaped.matrix() = + indices.shaped(ix_shape.dim_sizes()).template cast(); + } + + // If we received a scalar, we'll need to create a new + // tensor with copies of the values as a vec. + // TODO(ebrevdo): find a way to avoid this temp allocation. 
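+ // For example, with 3 sparse indices and a scalar sparse_values of 7,
+ // sparse_values_b below becomes the length-3 vector [7, 7, 7].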
+ Tensor sparse_values_b; + + if (TensorShapeUtils::IsScalar(sparse_values.shape())) { + OP_REQUIRES_OK( + c, c->allocate_temp(DataTypeToEnum::value, + TensorShape({num_elems}), &sparse_values_b)); + sparse_values_b.vec().setConstant(sparse_values.scalar()()); + } else { + sparse_values_b = sparse_values; + } + + gtl::InlinedVector order(output->shape().dims()); + std::iota(order.begin(), order.end(), 0); // Assume order is correct + sparse::SparseTensor st(indices_shaped, sparse_values_b, output->shape(), + order); + + output->flat().setConstant(default_value.scalar()()); + OP_REQUIRES(c, st.template ToDense(output, false /* initialize */), + errors::InvalidArgument( + "Indices are not valid (out of bounds). Shape: ", + output->shape().DebugString())); + } +}; + +#define REGISTER_KERNELS(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("SparseToDense") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SparseToDense); + +#define REGISTER_KERNELS_ALL(type) \ + REGISTER_KERNELS(type, int32); \ + REGISTER_KERNELS(type, int64); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS_ALL); +REGISTER_KERNELS_ALL(bool); +REGISTER_KERNELS_ALL(string); + +#undef REGISTER_KERNELS_ALL +#undef REGISTER_KERNELS + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_to_dense_op_test.cc b/tensorflow/core/kernels/sparse_to_dense_op_test.cc new file mode 100644 index 0000000000..e9800ccd68 --- /dev/null +++ b/tensorflow/core/kernels/sparse_to_dense_op_test.cc @@ -0,0 +1,283 @@ +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/tensor.h" +#include + +namespace tensorflow { + +namespace { + +class SparseToDenseTest : public OpsTestBase { + protected: + void SetUp() override { RequireDefaultOps(); } + + void MakeOp(int dim, DataType index_type, DataType value_type) { + ASSERT_OK(NodeDefBuilder("sparsetodense", "SparseToDense") + .Input(FakeInput(index_type)) + .Input(FakeInput(index_type)) + .Input(FakeInput(value_type)) + .Input(FakeInput(value_type)) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(SparseToDenseTest, OneD_OneValue) { + MakeOp(1, DT_INT32, DT_FLOAT); + + // sparse_indices + AddInputFromArray(TensorShape({3}), {1, 3, 4}); + // output_shape + AddInputFromArray(TensorShape({1}), {5}); + // sparse_values + AddInputFromArray(TensorShape({}), {2}); + // default_value + AddInputFromArray(TensorShape({}), {-2}); + + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, {5}); + test::FillValues(&expected, {-2, 2, -2, 2, 2}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(SparseToDenseTest, OneD_OneValue_int64_double) { + MakeOp(1, DT_INT64, DT_DOUBLE); + + // sparse_indices + AddInputFromArray(TensorShape({3}), {1, 3, 4}); + // output_shape + AddInputFromArray(TensorShape({1}), {5}); + // 
sparse_values + AddInputFromArray(TensorShape({}), {2}); + // default_value + AddInputFromArray(TensorShape({}), {-2}); + + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_DOUBLE, {5}); + test::FillValues(&expected, {-2, 2, -2, 2, 2}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(SparseToDenseTest, OneD_MultValues) { + MakeOp(1, DT_INT32, DT_FLOAT); + + // sparse_indices + AddInputFromArray({3}, {1, 3, 4}); + // output_shape + AddInputFromArray({1}, {5}); + // sparse_values + AddInputFromArray({3}, {3, 4, 5}); + // default_value + AddInputFromArray({}, {-2}); + + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, {5}); + test::FillValues(&expected, {-2, 3, -2, 4, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(SparseToDenseTest, TwoD_OneValue) { + MakeOp(2, DT_INT32, DT_FLOAT); + + // sparse_indices + AddInputFromArray(TensorShape({3, 2}), {0, 1, 0, 2, 2, 3}); + // output_shape + AddInputFromArray(TensorShape({2}), {3, 4}); + // sparse_values + AddInputFromArray(TensorShape({}), {2}); + // default_value + AddInputFromArray(TensorShape({}), {-2}); + + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, {3, 4}); + expected.flat().setConstant(-2); + expected.tensor()(0, 1) = 2; + expected.tensor()(0, 2) = 2; + expected.tensor()(2, 3) = 2; + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(SparseToDenseTest, TwoD_MultValues) { + MakeOp(2, DT_INT32, DT_FLOAT); + + // sparse_indices + AddInputFromArray(TensorShape({3, 2}), {0, 1, 0, 2, 2, 3}); + // output_shape + AddInputFromArray(TensorShape({2}), {3, 4}); + // sparse_values + AddInputFromArray(TensorShape({3}), {3, 4, 5}); + // default_value + AddInputFromArray(TensorShape({}), {-2}); + + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, {3, 4}); + expected.flat().setConstant(-2); + expected.tensor()(0, 1) = 3; + expected.tensor()(0, 2) = 4; + expected.tensor()(2, 3) = 5; + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(SparseToDenseTest, ThreeD_OneValue) { + MakeOp(3, DT_INT32, DT_FLOAT); + + // sparse_indices + AddInputFromArray(TensorShape({3, 3}), {0, 1, 1, 0, 2, 0, 2, 3, 1}); + // output_shape + AddInputFromArray(TensorShape({3}), {3, 4, 2}); + // sparse_values + AddInputFromArray(TensorShape({}), {2}); + // default_value + AddInputFromArray(TensorShape({}), {-2}); + + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, {3, 4, 2}); + expected.flat().setConstant(-2); + expected.tensor()(0, 1, 1) = 2; + expected.tensor()(0, 2, 0) = 2; + expected.tensor()(2, 3, 1) = 2; + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(SparseToDenseTest, ThreeD_MultValues) { + MakeOp(3, DT_INT32, DT_FLOAT); + + // sparse_indices + AddInputFromArray(TensorShape({3, 3}), {0, 1, 1, 0, 2, 0, 2, 3, 1}); + // output_shape + AddInputFromArray(TensorShape({3}), {3, 4, 2}); + // sparse_values + AddInputFromArray(TensorShape({3}), {3, 4, 5}); + // default_value + AddInputFromArray(TensorShape({}), {-2}); + + ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, {3, 4, 2}); + expected.flat().setConstant(-2); + expected.tensor()(0, 1, 1) = 3; + expected.tensor()(0, 2, 0) = 4; + expected.tensor()(2, 3, 1) = 5; + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +} // namespace + +static int BM_Arg(int ndim, int n) { return (ndim * 1000000) + n; } +static int NDIM_from_arg(int bm_arg) { return bm_arg / 1000000; } +static int N_from_arg(int bm_arg) { return bm_arg % 1000000; } + 
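+// For example, BM_Arg(3, 10000) encodes to 3010000, which decodes back via
+// NDIM_from_arg(3010000) == 3 and N_from_arg(3010000) == 10000.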
+static void BM_SparseToDense(int iters, const int bm_arg) { + const int NDIM = NDIM_from_arg(bm_arg); + const int N = N_from_arg(bm_arg); + // TODO(zhifengc): Switch to use kernel_benchmark_testlib.h + tensorflow::testing::StopTiming(); + + const int IndexDim = (NDIM == 1) ? 0 : 1; + + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + + gtl::InlinedVector inputs; + + // Create a dense tensor with dims [1, ..., 1, N] + Tensor output_shape(DT_INT32, TensorShape({NDIM})); + Tensor sparse_indices(DT_INT32, TensorShape({N, NDIM})); + Tensor sparse_values(DT_FLOAT, TensorShape({N})); + Tensor default_value(DT_FLOAT, TensorShape({})); + auto output_shape_t = output_shape.vec(); + for (int d = 0; d < NDIM; ++d) { + output_shape_t(d) = (d == IndexDim) ? N : 3; + } + + auto sparse_indices_t = sparse_indices.matrix(); + for (int n = 0; n < N; ++n) { + for (int d = 0; d < NDIM; ++d) + sparse_indices_t(n, d) = (d == IndexDim) ? n : 0; + } + + for (auto* ptr : + {&sparse_indices, &output_shape, &sparse_values, &default_value}) { + inputs.push_back({nullptr, ptr}); + } + + NodeDef sparse_node_def; + TF_CHECK_OK(NodeDefBuilder("sparsetodense", "SparseToDense") + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(&sparse_node_def)); + + Status status; + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, device.get(), cpu_allocator(), sparse_node_def, &status)); + + OpKernelContext::Params params; + params.device = device.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = op.get(); + params.output_alloc_attr = [&device, &op, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + + std::unique_ptr sparse_context(new OpKernelContext(params)); + op->Compute(sparse_context.get()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + delete sparse_context->release_output(0).tensor; + op->Compute(sparse_context.get()); + ASSERT_OK(sparse_context->status()); + } + tensorflow::testing::StopTiming(); + + // processing input, mainly + int64 bytes_per_iter = static_cast((N + N * NDIM) * sizeof(float)); + + tensorflow::testing::BytesProcessed(bytes_per_iter * iters); +} + +BENCHMARK(BM_SparseToDense) + ->Arg(BM_Arg(1, 10)) + ->Arg(BM_Arg(1, 100)) + ->Arg(BM_Arg(1, 1000)) + ->Arg(BM_Arg(1, 10000)) + ->Arg(BM_Arg(2, 10)) + ->Arg(BM_Arg(2, 100)) + ->Arg(BM_Arg(2, 1000)) + ->Arg(BM_Arg(2, 10000)) + ->Arg(BM_Arg(3, 10)) + ->Arg(BM_Arg(3, 100)) + ->Arg(BM_Arg(3, 1000)) + ->Arg(BM_Arg(3, 10000)) + ->Arg(BM_Arg(5, 10)) + ->Arg(BM_Arg(5, 100)) + ->Arg(BM_Arg(5, 1000)) + ->Arg(BM_Arg(5, 10000)); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc new file mode 100644 index 0000000000..f4f9ada000 --- /dev/null +++ b/tensorflow/core/kernels/split_op.cc @@ -0,0 +1,146 @@ +// See docs in ../ops/array_ops.cc. 
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/split_op.h" + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class SplitOp : public OpKernel { + public: + explicit SplitOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* context) override { + const int32 split_dim = context->input(0).flat()(0); + const int32 num_split = num_outputs(); + const Tensor& input = context->input(1); + const TensorShape& input_shape = input.shape(); + + OP_REQUIRES( + context, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("0 <= split_dim < number of input dimensions (", + input_shape.dims(), "), but got ", split_dim)); + + OP_REQUIRES( + context, num_split > 0, + errors::InvalidArgument( + "Number of ways to split should be > 0, but got ", num_split)); + + OP_REQUIRES(context, input_shape.dim_size(split_dim) % num_split == 0, + errors::InvalidArgument( + "Number of ways to split should evenly divide the split " + "dimension, but got split_dim ", + split_dim, " (size = ", input_shape.dim_size(split_dim), + ") ", "and num_split ", num_split)); + + // Special case 1: num_split == 1. Nothing to do. + if (num_split == 1) { + VLOG(1) << "Split identity"; + context->set_output(0, context->input(1)); + return; + } + + // Special case 2: split along the 1st dimension. We can share the + // underlying buffer. + // + // Apply this optimization conservatively: if input is aligned, + // the resulting tensors must be aligned. It's conservative + // because if the immediate consumer of the resulting tensors are + // not using eigen for computation, its perfectly fine to avoid + // the copying. 
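+ // For example, an aligned [6, 4] float input split 3 ways along dim 0
+ // yields three [2, 4] outputs that alias rows [0,2), [2,4) and [4,6) of
+ // the input buffer; no data is copied.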
+ if ((split_dim == 0) && IsInnerDimsSizeAligned(input_shape)) { + VLOG(1) << "Slice dim 0: " << input_shape.DebugString(); + const int64 delta = input_shape.dim_size(0) / num_split; + for (int i = 0; i < num_split; ++i) { + context->set_output(i, input.Slice(i * delta, (i + 1) * delta)); + } + return; + } + + int32 prefix_dim_size = 1; + for (int i = 0; i < split_dim; ++i) { + prefix_dim_size *= input_shape.dim_size(i); + } + + int32 split_dim_size = input_shape.dim_size(split_dim); + + int32 suffix_dim_size = 1; + for (int i = split_dim + 1; i < input_shape.dims(); ++i) { + suffix_dim_size *= input_shape.dim_size(i); + } + + auto input_reshaped = + input.shaped({prefix_dim_size, split_dim_size, suffix_dim_size}); + + const int32 split_dim_output_size = split_dim_size / num_split; + TensorShape output_shape(input_shape); + output_shape.set_dim(split_dim, split_dim_output_size); + + Eigen::DSizes indices{0, 0, 0}; + Eigen::DSizes sizes{prefix_dim_size, split_dim_output_size, + suffix_dim_size}; + + for (int i = 0; i < num_split; ++i) { + Tensor* result = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(i, output_shape, &result)); + if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) { + Eigen::DSizes slice_indices; + Eigen::DSizes slice_sizes; + for (int j = 0; j < 3; ++j) { + slice_indices[j] = indices[j]; + slice_sizes[j] = sizes[j]; + } + + auto result_shaped = result->shaped( + {prefix_dim_size, split_dim_output_size, suffix_dim_size}); + + functor::Split()(context->eigen_device(), + result_shaped, input_reshaped, + slice_indices, slice_sizes); + } + indices[1] += split_dim_output_size; + } + } +}; + +#define REGISTER_SPLIT(type) \ + REGISTER_KERNEL_BUILDER(Name("Split") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("split_dim"), \ + SplitOp) + +TF_CALL_ALL_TYPES(REGISTER_SPLIT); + +#undef REGISTER_SPLIT + +#if GOOGLE_CUDA + +#define REGISTER_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("Split") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("split_dim"), \ + SplitOp) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +#undef REGISTER_GPU + +#endif // GOOGLE_CUDA + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/split_op.h b/tensorflow/core/kernels/split_op.h new file mode 100644 index 0000000000..2572c77285 --- /dev/null +++ b/tensorflow/core/kernels/split_op.h @@ -0,0 +1,31 @@ +#ifndef TENSORFLOW_KERNELS_SPLIT_OP_H_ +#define TENSORFLOW_KERNELS_SPLIT_OP_H_ +// Functor definition for SplitOp, must be compilable by nvcc. 
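The general path in SplitOp above reshapes the input to {prefix_dim_size, split_dim_size, suffix_dim_size} and then takes num_split equal slices along the middle dimension. Here is a small standalone sketch of that slicing using Eigen's unsupported Tensor module; the [2, 6, 3] shape and 3-way split are illustrative.

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // A [2, 6, 3] input viewed as {prefix = 2, split_dim = 6, suffix = 3},
  // split 3 ways along the middle dimension.
  Eigen::Tensor<float, 3> input(2, 6, 3);
  input.setRandom();

  const int num_split = 3;
  const Eigen::Index piece = input.dimension(1) / num_split;  // 2
  for (int i = 0; i < num_split; ++i) {
    Eigen::array<Eigen::Index, 3> offsets{0, i * piece, 0};
    Eigen::array<Eigen::Index, 3> extents{input.dimension(0), piece,
                                          input.dimension(2)};
    Eigen::Tensor<float, 3> out = input.slice(offsets, extents);
    std::cout << "piece " << i << ": " << out.size() << " elements\n";
  }
  return 0;
}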
+ +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +template +struct Split { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes); +}; + +template +struct Split { + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_SPLIT_OP_H_ diff --git a/tensorflow/core/kernels/split_op_cpu.cc b/tensorflow/core/kernels/split_op_cpu.cc new file mode 100644 index 0000000000..b86deeb8fb --- /dev/null +++ b/tensorflow/core/kernels/split_op_cpu.cc @@ -0,0 +1,30 @@ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/split_op.h" + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +void Split::operator()( + const Eigen::ThreadPoolDevice& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes) { + if (output.size() < 131072) { + output = input.slice(slice_indices, slice_sizes); + } else { + output.device(d) = input.slice(slice_indices, slice_sizes); + } +} + +#define DEFINE_CPU_KERNELS(T) template struct Split; + +TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS) + +} // namespace functor +} // namespace tensorflow diff --git a/tensorflow/core/kernels/split_op_gpu.cu.cc b/tensorflow/core/kernels/split_op_gpu.cu.cc new file mode 100644 index 0000000000..f8931d6a89 --- /dev/null +++ b/tensorflow/core/kernels/split_op_gpu.cu.cc @@ -0,0 +1,31 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include + +#include "tensorflow/core/kernels/split_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +void Split::operator()( + const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes) { + output.device(d) = input.slice(slice_indices, slice_sizes); +} + +#define DEFINE_GPU_KERNELS(T) template struct Split; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc new file mode 100644 index 0000000000..bd6fa47268 --- /dev/null +++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc @@ -0,0 +1,47 @@ +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +class StringToHashBucketOp : public OpKernel { + public: + explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(context, 
context->input("string_tensor", &input_tensor)); + const auto& input_flat = input_tensor->flat(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat(); + + for (int i = 0; i < input_flat.size(); ++i) { + const uint64 input_hash = Hash64(input_flat(i)); + const uint64 bucket_id = input_hash % num_buckets_; + // The number of buckets is always in the positive range of int64 so is + // the resulting bucket_id. Casting the bucket_id from uint64 to int64 is + // safe. + output_flat(i) = static_cast(bucket_id); + } + } + + private: + int64 num_buckets_; + + TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp); +}; + +REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU), + StringToHashBucketOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc new file mode 100644 index 0000000000..8d23a4fdf8 --- /dev/null +++ b/tensorflow/core/kernels/string_to_number_op.cc @@ -0,0 +1,71 @@ +// See docs in ../ops/parse_ops.cc. + +#include +#include + +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +static constexpr char kErrorMessage[] = + "StringToNumberOp could not correctly convert string: "; + +template +class StringToNumberOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + // This is not a deep copy of the input tensor; they will share the same + // underlying storage. + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("string_tensor", &input_tensor)); + const auto& input_flat = input_tensor->flat(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat(); + + for (int i = 0; i < input_flat.size(); ++i) { + const char* s = input_flat(i).data(); + Convert(s, &output_flat(i), context); + } + } + + private: + void Convert(const char* s, OutputType* output_data, + OpKernelContext* context); +}; + +template <> +void StringToNumberOp::Convert(const char* s, float* output_data, + OpKernelContext* context) { + OP_REQUIRES(context, strings::safe_strtof(s, output_data), + errors::InvalidArgument(kErrorMessage, s)); +} + +template <> +void StringToNumberOp::Convert(const char* s, int32* output_data, + OpKernelContext* context) { + OP_REQUIRES(context, strings::safe_strto32(s, output_data), + errors::InvalidArgument(kErrorMessage, s)); +} + +// Registers the currently supported output types. +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER(Name("StringToNumber") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("out_type"), \ + StringToNumberOp) +REGISTER(float); +REGISTER(int32); +#undef REGISTER + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc new file mode 100644 index 0000000000..ba765f2e84 --- /dev/null +++ b/tensorflow/core/kernels/summary_image_op.cc @@ -0,0 +1,169 @@ +// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as +// inputs or outputs in various ways. + +// See docs in ../ops/summary_ops.cc. 
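The StringToHashBucketOp above maps each input string to Hash64(string) % num_buckets and casts the result to int64. Below is a standalone sketch of the same mod-the-hash idea; std::hash is used as a stand-in for TensorFlow's Hash64, so the bucket ids it prints will differ from what the op produces.

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

// Stand-in for the kernel's Hash64(input) % num_buckets logic. num_buckets is
// positive, so the uint64 remainder always fits in an int64.
int64_t ToBucket(const std::string& s, int64_t num_buckets) {
  const uint64_t h = std::hash<std::string>{}(s);
  return static_cast<int64_t>(h % static_cast<uint64_t>(num_buckets));
}

int main() {
  for (const std::string s : {"apple", "banana", "cherry"}) {
    std::cout << s << " -> bucket " << ToBucket(s, 16) << "\n";
  }
  return 0;
}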
+ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/png/png_io.h" + +namespace tensorflow { + +class SummaryImageOp : public OpKernel { + public: + explicit SummaryImageOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("max_images", &max_images_)); + const TensorProto* proto; + OP_REQUIRES_OK(context, context->GetAttr("bad_color", &proto)); + OP_REQUIRES_OK(context, context->device()->MakeTensorFromProto( + *proto, AllocatorAttributes(), &bad_color_)); + OP_REQUIRES(context, bad_color_.dtype() == DT_UINT8, + errors::InvalidArgument("bad_color must be uint8, got ", + DataTypeString(bad_color_.dtype()))); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(bad_color_.shape()), + errors::InvalidArgument("bad_color must be a vector, got shape ", + bad_color_.shape().ShortDebugString())); + } + + void Compute(OpKernelContext* c) override { + const Tensor& tags = c->input(0); + const Tensor& tensor = c->input(1); + OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()), + errors::InvalidArgument("Tags must have be a scalar")); + OP_REQUIRES(c, tensor.dims() == 4 && + (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 || + tensor.dim_size(3) == 4), + errors::InvalidArgument( + "Tensor must be 4-D with last dim 1, 3, or 4, not ", + tensor.shape().DebugString())); + const string& base_tag = tags.scalar()(); + + const int batch_size = tensor.dim_size(0); + const int h = tensor.dim_size(1); + const int w = tensor.dim_size(2); + const int hw = h * w; // Compact these two dims for simplicity + const int depth = tensor.dim_size(3); + auto tensor_eigen = tensor.shaped({batch_size, hw, depth}); + + OP_REQUIRES(c, bad_color_.dim_size(0) >= depth, + errors::InvalidArgument( + "expected depth <= bad_color.size, got depth = ", depth, + ", bad_color.size = ", bad_color_.dim_size(0))); + auto bad_color_full = bad_color_.vec(); + typename TTypes::Vec bad_color(bad_color_full.data(), depth); + + // RGB (or gray or RGBA) is last dimension + Eigen::Tensor image(hw, depth); + + Summary s; + const int N = std::min(max_images_, batch_size); + for (int i = 0; i < N; ++i) { + Summary::Value* v = s.add_value(); + // The tag depends on the number of requested images (not the number + // produced.) + // + // Note that later on avisu uses "/" to figure out a consistent naming + // convention for display, so we append "/image" to guarantee that the + // image(s) won't be displayed in the global scope with no name. + if (max_images_ > 1) { + v->set_tag(strings::StrCat(base_tag, "/image/", i)); + } else { + v->set_tag(strings::StrCat(base_tag, "/image")); + } + + if (image.size()) { + typename TTypes::ConstMatrix values( + &tensor_eigen(i, 0, 0), + Eigen::DSizes(hw, depth)); + + // Rescale the image to uint8 range. + // + // We are trying to generate an RCG image from a float tensor. We do + // not have any info about the expected range of values in the tensor + // but the generated image needs to have all RGB values within [0, 255]. + // + // We use two different algorithms to generate these values. If the + // tensor has only positive values we scale them all by 255/max(values). + // If the tensor has both negative and positive values we scale them by + // the max of their absolute values and center them around 127. 
+ // + // This works for most cases, but has the incovenient of not respecting + // the relative dynamic range across different instances of the tensor. + + // Compute min and max ignoring nonfinite pixels + float image_min = std::numeric_limits::infinity(); + float image_max = -image_min; + for (int i = 0; i < hw; i++) { + bool finite = true; + for (int j = 0; j < depth; j++) { + if (!std::isfinite(values(i, j))) { + finite = false; + break; + } + } + if (finite) { + for (int j = 0; j < depth; j++) { + float value = values(i, j); + image_min = std::min(image_min, value); + image_max = std::max(image_max, value); + } + } + } + + // Pick an affine transform into uint8 + const float kZeroThreshold = 1e-6; + float scale, offset; + if (image_min < 0) { + float max_val = std::max(std::abs(image_min), std::abs(image_max)); + scale = max_val < kZeroThreshold ? 0.0f : 127.0f / max_val; + offset = 128.0f; + } else { + scale = image_max < kZeroThreshold ? 0.0f : 255.0f / image_max; + offset = 0.0f; + } + + // Transform image, turning nonfinite values to bad_color + for (int i = 0; i < hw; i++) { + bool finite = true; + for (int j = 0; j < depth; j++) { + if (!std::isfinite(values(i, j))) { + finite = false; + break; + } + } + if (finite) { + image.chip<0>(i) = + (values.chip<0>(i) * scale + offset).cast(); + } else { + image.chip<0>(i) = bad_color; + } + } + } + + Summary::Image* si = v->mutable_image(); + si->set_height(h); + si->set_width(w); + si->set_colorspace(depth); + OP_REQUIRES(c, png::WriteImageToBuffer( + image.data(), w, h, w * depth, depth, 8, -1, + si->mutable_encoded_image_string(), nullptr), + errors::Internal("PNG encoding failed")); + } + + Tensor* summary_tensor = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); + CHECK(s.SerializeToString(&summary_tensor->scalar()())); + } + + private: + int64 max_images_; + Tensor bad_color_; +}; + +REGISTER_KERNEL_BUILDER(Name("ImageSummary").Device(DEVICE_CPU), + SummaryImageOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/summary_image_op_test.cc b/tensorflow/core/kernels/summary_image_op_test.cc new file mode 100644 index 0000000000..ddfeeffc0b --- /dev/null +++ b/tensorflow/core/kernels/summary_image_op_test.cc @@ -0,0 +1,141 @@ +#include +#include +#include + +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace { + +static void EXPECT_SummaryMatches(const Summary& actual, + const string& expected_str) { + Summary expected; + CHECK(protobuf::TextFormat::ParseFromString(expected_str, &expected)); + EXPECT_EQ(expected.DebugString(), actual.DebugString()); +} + +// -------------------------------------------------------------------------- +// SummaryImageOp +// 
-------------------------------------------------------------------------- +class SummaryImageOpTest : public OpsTestBase { + protected: + void MakeOp(int max_images) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "ImageSummary") + .Input(FakeInput()) + .Input(FakeInput()) + .Attr("max_images", max_images) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } + + void CheckAndRemoveEncodedImages(Summary* summary) { + for (int i = 0; i < summary->value_size(); ++i) { + Summary::Value* value = summary->mutable_value(i); + ASSERT_TRUE(value->has_image()) << "No image for value: " << value->tag(); + ASSERT_FALSE(value->image().encoded_image_string().empty()) + << "No encoded_image_string for value: " << value->tag(); + if (VLOG_IS_ON(2)) { + // When LOGGING, output the images to disk for manual inspection. + TF_CHECK_OK(WriteStringToFile( + Env::Default(), strings::StrCat("/tmp/", value->tag(), ".png"), + value->image().encoded_image_string())); + } + value->mutable_image()->clear_encoded_image_string(); + } + } +}; + +TEST_F(SummaryImageOpTest, ThreeGrayImagesOutOfFive4dInput) { + MakeOp(3 /* max images */); + + // Feed and run + AddInputFromArray(TensorShape({}), {"tag"}); + AddInputFromArray(TensorShape({5, 2, 1, 1}), + {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); + ASSERT_OK(RunOpKernel()); + + // Check the output size. + Tensor* out_tensor = GetOutput(0); + ASSERT_EQ(0, out_tensor->dims()); + Summary summary; + ParseProtoUnlimited(&summary, out_tensor->scalar()()); + + CheckAndRemoveEncodedImages(&summary); + EXPECT_SummaryMatches(summary, R"( + value { tag: 'tag/image/0' image { width: 1 height: 2 colorspace: 1} } + value { tag: 'tag/image/1' image { width: 1 height: 2 colorspace: 1} } + value { tag: 'tag/image/2' image { width: 1 height: 2 colorspace: 1} } + )"); +} + +TEST_F(SummaryImageOpTest, OneGrayImage4dInput) { + MakeOp(1 /* max images */); + + // Feed and run + AddInputFromArray(TensorShape({}), {"tag"}); + AddInputFromArray(TensorShape({5 /*batch*/, 2, 1, 1 /*depth*/}), + {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); + ASSERT_OK(RunOpKernel()); + + // Check the output size. + Tensor* out_tensor = GetOutput(0); + ASSERT_EQ(0, out_tensor->dims()); + Summary summary; + ParseProtoUnlimited(&summary, out_tensor->scalar()()); + + CheckAndRemoveEncodedImages(&summary); + EXPECT_SummaryMatches(summary, R"( + value { tag: 'tag/image' image { width: 1 height: 2 colorspace: 1} })"); +} + +TEST_F(SummaryImageOpTest, OneColorImage4dInput) { + MakeOp(1 /* max images */); + + // Feed and run + AddInputFromArray(TensorShape({}), {"tag"}); + AddInputFromArray( + TensorShape({1 /*batch*/, 5 /*rows*/, 2 /*columns*/, 3 /*depth*/}), + { + /* r0, c0, RGB */ 1.0, 0.1, 0.2, + /* r0, c1, RGB */ 1.0, 0.3, 0.4, + /* r1, c0, RGB */ 0.0, 1.0, 0.0, + /* r1, c1, RGB */ 0.0, 1.0, 0.0, + /* r2, c0, RGB */ 0.0, 0.0, 1.0, + /* r2, c1, RGB */ 0.0, 0.0, 1.0, + /* r3, c0, RGB */ 1.0, 1.0, 0.0, + /* r3, c1, RGB */ 1.0, 0.0, 1.0, + /* r4, c0, RGB */ 1.0, 1.0, 0.0, + /* r4, c1, RGB */ 1.0, 0.0, 1.0, + }); + ASSERT_OK(RunOpKernel()); + + // Check the output size. 
+ Tensor* out_tensor = GetOutput(0); + ASSERT_EQ(0, out_tensor->dims()); + Summary summary; + ParseProtoUnlimited(&summary, out_tensor->scalar()()); + + CheckAndRemoveEncodedImages(&summary); + EXPECT_SummaryMatches(summary, R"( + value { tag: 'tag/image' image { width: 2 height: 5 colorspace: 3} })"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc new file mode 100644 index 0000000000..1c4be64b8b --- /dev/null +++ b/tensorflow/core/kernels/summary_op.cc @@ -0,0 +1,141 @@ +// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as +// inputs or outputs in various ways. + +// See docs in ../ops/summary_ops.cc. + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +template +class SummaryScalarOp : public OpKernel { + public: + explicit SummaryScalarOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* c) override { + const Tensor& tags = c->input(0); + const Tensor& values = c->input(1); + + OP_REQUIRES(c, tags.IsSameSize(values) || + (TensorShapeUtils::IsLegacyScalar(tags.shape()) && + TensorShapeUtils::IsLegacyScalar(values.shape())), + errors::InvalidArgument("tags and values not the same shape: ", + tags.shape().ShortDebugString(), " != ", + values.shape().ShortDebugString())); + auto Ttags = tags.flat(); + auto Tvalues = values.flat(); + Summary s; + for (int i = 0; i < Ttags.size(); i++) { + Summary::Value* v = s.add_value(); + v->set_tag(Ttags(i)); + v->set_simple_value(Tvalues(i)); + } + + Tensor* summary_tensor = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); + CHECK(s.SerializeToString(&summary_tensor->scalar()())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ScalarSummary") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + SummaryScalarOp); +REGISTER_KERNEL_BUILDER(Name("ScalarSummary") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + SummaryScalarOp); + +class SummaryHistoOp : public OpKernel { + public: + // SummaryHistoOp could be extended to take a list of custom bucket + // boundaries as an option. 
+ explicit SummaryHistoOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* c) override { + const Tensor& tags = c->input(0); + const Tensor& values = c->input(1); + const auto flat = values.flat(); + OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()), + errors::InvalidArgument("tags must be scalar")); + // Build histogram of values in "values" tensor + histogram::Histogram histo; + for (int64 i = 0; i < flat.size(); i++) { + float v = flat(i); + if (!std::isfinite(v)) { + c->SetStatus( + errors::OutOfRange("Nan in summary histogram for: ", name())); + break; + } + histo.Add(v); + } + + Summary s; + Summary::Value* v = s.add_value(); + v->set_tag(tags.scalar()()); + histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); + + Tensor* summary_tensor = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); + CHECK(s.SerializeToString(&summary_tensor->scalar()())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("HistogramSummary").Device(DEVICE_CPU), + SummaryHistoOp); + +struct HistogramResource : public ResourceBase { + histogram::ThreadSafeHistogram histogram; + + string DebugString() override { return "A historam summary. Stats ..."; } +}; + +class SummaryMergeOp : public OpKernel { + public: + explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* c) override { + Summary s; + std::unordered_set tags; + for (int input_num = 0; input_num < c->num_inputs(); input_num++) { + const Tensor& in = c->input(input_num); + auto in_vec = in.flat(); + for (int i = 0; i < in_vec.dimension(0); i++) { + const string& s_in = in_vec(i); + Summary summary_in; + if (!ParseProtoUnlimited(&summary_in, s_in)) { + c->SetStatus(errors::InvalidArgument( + "Could not parse one of the summary inputs")); + return; + } + + for (int v = 0; v < summary_in.value_size(); v++) { + if (!tags.insert(summary_in.value(v).tag()).second) { + c->SetStatus(errors::InvalidArgument( + strings::StrCat("Duplicate tag ", summary_in.value(v).tag(), + " found in summary inputs"))); + return; + } + *s.add_value() = summary_in.value(v); + } + } + } + + Tensor* summary_tensor = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); + CHECK(s.SerializeToString(&summary_tensor->scalar()())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("MergeSummary").Device(DEVICE_CPU), + SummaryMergeOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/summary_op_test.cc b/tensorflow/core/kernels/summary_op_test.cc new file mode 100644 index 0000000000..fd271a6862 --- /dev/null +++ b/tensorflow/core/kernels/summary_op_test.cc @@ -0,0 +1,282 @@ +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include +#include 
"tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +static void EXPECT_SummaryMatches(const Summary& actual, + const string& expected_str) { + Summary expected; + CHECK(protobuf::TextFormat::ParseFromString(expected_str, &expected)); + EXPECT_EQ(expected.DebugString(), actual.DebugString()); +} + +class SummaryScalarOpTest : public OpsTestBase { + protected: + void MakeOp(DataType dt) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "ScalarSummary") + .Input(FakeInput()) + .Input(FakeInput(dt)) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(SummaryScalarOpTest, SimpleFloat) { + MakeOp(DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({3}), {"tag1", "tag2", "tag3"}); + AddInputFromArray(TensorShape({3}), {1.0, -0.73, 10000.0}); + ASSERT_OK(RunOpKernel()); + + // Check the output size. + Tensor* out_tensor = GetOutput(0); + ASSERT_EQ(0, out_tensor->dims()); + Summary summary; + ParseProtoUnlimited(&summary, out_tensor->scalar()()); + EXPECT_SummaryMatches(summary, R"( + value { tag: 'tag1' simple_value: 1.0 } + value { tag: 'tag2' simple_value: -0.73 } + value { tag: 'tag3' simple_value: 10000.0 } + )"); +} + +TEST_F(SummaryScalarOpTest, SimpleDouble) { + MakeOp(DT_DOUBLE); + + // Feed and run + AddInputFromArray(TensorShape({3}), {"tag1", "tag2", "tag3"}); + AddInputFromArray(TensorShape({3}), {1.0, -0.73, 10000.0}); + ASSERT_OK(RunOpKernel()); + + // Check the output size. + Tensor* out_tensor = GetOutput(0); + ASSERT_EQ(0, out_tensor->dims()); + Summary summary; + ParseProtoUnlimited(&summary, out_tensor->scalar()()); + EXPECT_SummaryMatches(summary, R"( + value { tag: 'tag1' simple_value: 1.0 } + value { tag: 'tag2' simple_value: -0.73 } + value { tag: 'tag3' simple_value: 10000.0 } + )"); +} + +TEST_F(SummaryScalarOpTest, Error_MismatchedSize) { + MakeOp(DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({2}), {"tag1", "tag2"}); + AddInputFromArray(TensorShape({3}), {1.0, -0.73, 10000.0}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()).contains("not the same shape")) << s; +} + +TEST_F(SummaryScalarOpTest, Error_WrongDimsTags) { + MakeOp(DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({2, 1}), {"tag1", "tag2"}); + AddInputFromArray(TensorShape({2}), {1.0, -0.73}); + Status s = RunOpKernel(); + EXPECT_TRUE( + StringPiece(s.ToString()).contains("tags and values not the same shape")) + << s; +} + +TEST_F(SummaryScalarOpTest, Error_WrongDimsValues) { + MakeOp(DT_FLOAT); + + // Feed and run + AddInputFromArray(TensorShape({2}), {"tag1", "tag2"}); + AddInputFromArray(TensorShape({2, 1}), {1.0, -0.73}); + Status s = RunOpKernel(); + EXPECT_TRUE( + StringPiece(s.ToString()).contains("tags and values not the same shape")) + << s; +} + +// -------------------------------------------------------------------------- +// SummaryHistoOp +// -------------------------------------------------------------------------- +class SummaryHistoOpTest : public OpsTestBase { + protected: + void MakeOp() { + ASSERT_OK(NodeDefBuilder("myop", "HistogramSummary") + .Input(FakeInput()) + .Input(FakeInput()) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(SummaryHistoOpTest, Simple) { + MakeOp(); + + // Feed and run + AddInputFromArray(TensorShape({}), {"taghisto"}); + AddInputFromArray(TensorShape({3, 2}), {0.1, -0.7, 4.1, 4., 5., 4.}); + ASSERT_OK(RunOpKernel()); + + // Check the output size. 
+ Tensor* out_tensor = GetOutput(0); + ASSERT_EQ(0, out_tensor->dims()); + Summary summary; + ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ASSERT_EQ(summary.value_size(), 1); + EXPECT_EQ(summary.value(0).tag(), "taghisto"); + histogram::Histogram histo; + EXPECT_TRUE(histo.DecodeFromProto(summary.value(0).histo())); + EXPECT_EQ( + "Count: 6 Average: 2.7500 StdDev: 2.20\n" + "Min: -0.7000 Median: 3.9593 Max: 5.0000\n" + "------------------------------------------------------\n" + "[ -0.76, -0.69 ) 1 16.667% 16.667% ###\n" + "[ 0.093, 0.1 ) 1 16.667% 33.333% ###\n" + "[ 3.8, 4.2 ) 3 50.000% 83.333% ##########\n" + "[ 4.6, 5.1 ) 1 16.667% 100.000% ###\n", + histo.ToString()); +} + +TEST_F(SummaryHistoOpTest, Error_WrongDimsTags) { + MakeOp(); + + // Feed and run + AddInputFromArray(TensorShape({2, 1}), {"tag1", "tag2"}); + AddInputFromArray(TensorShape({2}), {1.0, -0.73}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()).contains("tags must be scalar")) << s; +} + +TEST_F(SummaryHistoOpTest, Error_TooManyTagValues) { + MakeOp(); + + // Feed and run + AddInputFromArray(TensorShape({2}), {"tag1", "tag2"}); + AddInputFromArray(TensorShape({2, 1}), {1.0, -0.73}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()).contains("tags must be scalar")) << s; +} + +// -------------------------------------------------------------------------- +// SummaryMergeOp +// -------------------------------------------------------------------------- +class SummaryMergeOpTest : public OpsTestBase { + protected: + void MakeOp(int num_inputs) { + ASSERT_OK(NodeDefBuilder("myop", "MergeSummary") + .Input(FakeInput(num_inputs)) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(SummaryMergeOpTest, Simple) { + MakeOp(1); + + // Feed and run + Summary s1; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "value { tag: \"tag1\" simple_value: 1.0 } " + "value { tag: \"tag2\" simple_value: -0.73 } ", + &s1)); + Summary s2; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "value { tag: \"tag3\" simple_value: 10000.0 }", &s2)); + Summary s3; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "value { tag: \"tag4\" simple_value: 11.0 }", &s3)); + + AddInputFromArray( + TensorShape({3}), + {s1.SerializeAsString(), s2.SerializeAsString(), s3.SerializeAsString()}); + ASSERT_OK(RunOpKernel()); + + // Check the output size. + Tensor* out_tensor = GetOutput(0); + ASSERT_EQ(0, out_tensor->dims()); + Summary summary; + ParseProtoUnlimited(&summary, out_tensor->scalar()()); + + EXPECT_SummaryMatches(summary, + "value { tag: \"tag1\" simple_value: 1.0 } " + "value { tag: \"tag2\" simple_value: -0.73 } " + "value { tag: \"tag3\" simple_value: 10000.0 }" + "value { tag: \"tag4\" simple_value: 11.0 }"); +} + +TEST_F(SummaryMergeOpTest, Simple_MultipleInputs) { + MakeOp(3); + + // Feed and run + Summary s1; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "value { tag: \"tag1\" simple_value: 1.0 } " + "value { tag: \"tag2\" simple_value: -0.73 } ", + &s1)); + Summary s2; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "value { tag: \"tag3\" simple_value: 10000.0 }", &s2)); + Summary s3; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "value { tag: \"tag4\" simple_value: 11.0 }", &s3)); + + AddInputFromArray(TensorShape({}), {s1.SerializeAsString()}); + AddInputFromArray(TensorShape({}), {s2.SerializeAsString()}); + AddInputFromArray(TensorShape({}), {s3.SerializeAsString()}); + ASSERT_OK(RunOpKernel()); + + // Check the output size. 
+ Tensor* out_tensor = GetOutput(0); + ASSERT_EQ(0, out_tensor->dims()); + Summary summary; + ParseProtoUnlimited(&summary, out_tensor->scalar()()); + + EXPECT_SummaryMatches(summary, + "value { tag: \"tag1\" simple_value: 1.0 } " + "value { tag: \"tag2\" simple_value: -0.73 } " + "value { tag: \"tag3\" simple_value: 10000.0 }" + "value { tag: \"tag4\" simple_value: 11.0 }"); +} + +TEST_F(SummaryMergeOpTest, Error_MismatchedSize) { + MakeOp(1); + + // Feed and run + Summary s1; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "value { tag: \"tag1\" simple_value: 1.0 } " + "value { tag: \"tagduplicate\" simple_value: -0.73 } ", + &s1)); + Summary s2; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "value { tag: \"tagduplicate\" simple_value: 1.0 } ", &s2)); + AddInputFromArray(TensorShape({2}), + {s1.SerializeAsString(), s2.SerializeAsString()}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()).contains("Duplicate tag")) << s; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/text_line_reader_op.cc b/tensorflow/core/kernels/text_line_reader_op.cc new file mode 100644 index 0000000000..51e4d6a2b8 --- /dev/null +++ b/tensorflow/core/kernels/text_line_reader_op.cc @@ -0,0 +1,99 @@ +// See docs in ../ops/io_ops.cc. + +#include +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/kernels/reader_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { + +class TextLineReader : public ReaderBase { + public: + TextLineReader(const string& node_name, int skip_header_lines, Env* env) + : ReaderBase(strings::StrCat("TextLineReader '", node_name, "'")), + skip_header_lines_(skip_header_lines), + env_(env), + line_number_(0) {} + + Status OnWorkStartedLocked() override { + line_number_ = 0; + RandomAccessFile* file = nullptr; + TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file)); + input_buffer_.reset(new io::InputBuffer(file, kBufferSize)); + for (; line_number_ < skip_header_lines_; ++line_number_) { + string line_contents; + Status status = input_buffer_->ReadLine(&line_contents); + if (errors::IsOutOfRange(status)) { + // We ignore an end of file error when skipping header lines. + // We will end up skipping this file. + return Status::OK(); + } + TF_RETURN_IF_ERROR(status); + } + return Status::OK(); + } + + Status OnWorkFinishedLocked() override { + input_buffer_.reset(nullptr); + return Status::OK(); + } + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + Status status = input_buffer_->ReadLine(value); + ++line_number_; + if (status.ok()) { + *key = strings::StrCat(current_work(), ":", line_number_); + *produced = true; + return status; + } + if (errors::IsOutOfRange(status)) { // End of file, advance to the next. + *at_end = true; + return Status::OK(); + } else { // Some other reading error + return status; + } + } + + Status ResetLocked() override { + line_number_ = 0; + input_buffer_.reset(nullptr); + return ReaderBase::ResetLocked(); + } + + // TODO(josh11b): Implement serializing and restoring the state. Need + // to create TextLineReaderState proto to store ReaderBaseState, + // line_number_, and input_buffer_->Tell(). 
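// Standalone sketch of the reader's key/value contract, using only the
// standard library (illustrative, not the reader itself): after skipping
// skip_header_lines lines, each produced record has value = the line
// contents and key = "<filename>:<1-based line number>", mirroring
// ReadLocked above.
#include <cstdint>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

namespace text_line_reader_sketch {

// Reads `filename` the way TextLineReader would: skips the header lines,
// then returns (key, value) pairs until end of file.
std::vector<std::pair<std::string, std::string>> ReadAll(
    const std::string& filename, int skip_header_lines) {
  std::vector<std::pair<std::string, std::string>> records;
  std::ifstream in(filename);
  std::string line;
  int64_t line_number = 0;
  for (; line_number < skip_header_lines && std::getline(in, line);
       ++line_number) {
    // Header lines are consumed and dropped.
  }
  while (std::getline(in, line)) {
    ++line_number;
    records.emplace_back(filename + ":" + std::to_string(line_number), line);
  }
  return records;
}

}  // namespace text_line_reader_sketch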
+ + private: + enum { kBufferSize = 256 << 10 /* 256 kB */ }; + const int skip_header_lines_; + Env* const env_; + int64 line_number_; + std::unique_ptr input_buffer_; +}; + +class TextLineReaderOp : public ReaderOpKernel { + public: + explicit TextLineReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + int skip_header_lines = -1; + OP_REQUIRES_OK(context, + context->GetAttr("skip_header_lines", &skip_header_lines)); + OP_REQUIRES(context, skip_header_lines >= 0, + errors::InvalidArgument("skip_header_lines must be >= 0 not ", + skip_header_lines)); + Env* env = context->env(); + SetReaderFactory([this, skip_header_lines, env]() { + return new TextLineReader(name(), skip_header_lines, env); + }); + } +}; + +REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU), + TextLineReaderOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc new file mode 100644 index 0000000000..551be18d5f --- /dev/null +++ b/tensorflow/core/kernels/tf_record_reader_op.cc @@ -0,0 +1,76 @@ +// See docs in ../ops/io_ops.cc. + +#include +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/kernels/reader_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { + +class TFRecordReader : public ReaderBase { + public: + TFRecordReader(const string& node_name, Env* env) + : ReaderBase(strings::StrCat("TFRecordReader '", node_name, "'")), + env_(env), + offset_(0) {} + + Status OnWorkStartedLocked() override { + offset_ = 0; + RandomAccessFile* file = nullptr; + TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file)); + file_.reset(file); + reader_.reset(new io::RecordReader(file)); + return Status::OK(); + } + + Status OnWorkFinishedLocked() override { + reader_.reset(nullptr); + file_.reset(nullptr); + return Status::OK(); + } + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + *key = strings::StrCat(current_work(), ":", offset_); + Status status = reader_->ReadRecord(&offset_, value); + if (errors::IsOutOfRange(status)) { + *at_end = true; + return Status::OK(); + } + if (!status.ok()) return status; + *produced = true; + return Status::OK(); + } + + Status ResetLocked() override { + offset_ = 0; + reader_.reset(nullptr); + file_.reset(nullptr); + return ReaderBase::ResetLocked(); + } + + // TODO(josh11b): Implement serializing and restoring the state. + + private: + Env* const env_; + uint64 offset_; + std::unique_ptr file_; + std::unique_ptr reader_; +}; + +class TFRecordReaderOp : public ReaderOpKernel { + public: + explicit TFRecordReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + Env* env = context->env(); + SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); }); + } +}; + +REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU), + TFRecordReaderOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc new file mode 100644 index 0000000000..d5e0e89d60 --- /dev/null +++ b/tensorflow/core/kernels/tile_ops.cc @@ -0,0 +1,460 @@ +// See docs in ../ops/array_ops.cc. 
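// Standalone sketch of Tile's semantics (illustrative, standard library
// only): output dimension i has size input_dim(i) * multiples[i], and
// element reads wrap modulo the input size, which is what the Eigen
// broadcast in functor::Tile computes for the kernel below.
#include <cstddef>
#include <vector>

namespace tile_sketch {

// Tiles a row-major (rows x cols) matrix by (mr, mc):
// out has shape (rows*mr, cols*mc) and out(i, j) = in(i % rows, j % cols).
std::vector<float> Tile2D(const std::vector<float>& in, int rows, int cols,
                          int mr, int mc) {
  std::vector<float> out(static_cast<std::size_t>(rows) * mr * cols * mc);
  const int out_cols = cols * mc;
  for (int i = 0; i < rows * mr; ++i) {
    for (int j = 0; j < out_cols; ++j) {
      out[static_cast<std::size_t>(i) * out_cols + j] =
          in[static_cast<std::size_t>(i % rows) * cols + (j % cols)];
    }
  }
  return out;
}

}  // namespace tile_sketch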
+ +#define EIGEN_USE_THREADS + +#ifdef GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/core/kernels/tile_ops.h" + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// -------------------------------------------------------------------------- +template +class TileOp : public OpKernel { + public: + explicit TileOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& multiples = context->input(1); + + OP_REQUIRES( + context, TensorShapeUtils::IsLegacyVector(multiples.shape()), + errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", + multiples.shape().ShortDebugString())); + OP_REQUIRES(context, input.dims() == multiples.NumElements(), + errors::InvalidArgument( + "Expected multiples argument to be a vector of length ", + input.dims(), " but got length ", multiples.dim_size(0))); + + const int input_dims = input.dims(); + const gtl::ArraySlice multiples_array(multiples.flat().data(), + input_dims); + + TensorShape output_shape; + for (int i = 0; i < input_dims; ++i) { + OP_REQUIRES( + context, multiples_array[i] > 0, + errors::InvalidArgument("Expected multiples[", i, "] > 0, but got ", + multiples_array[i])); + output_shape.AddDim(input.dim_size(i) * multiples_array[i]); + } + Tensor* result = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result)); + +#define HANDLE_DIM(DT, NDIM) \ + if (context->input(0).dtype() == DT && input_dims == NDIM) { \ + HandleCase(context, multiples_array, result); \ + return; \ + } + +#define HANDLE_TYPE(T) \ + HANDLE_DIM(T, 0) \ + HANDLE_DIM(T, 1) \ + HANDLE_DIM(T, 2) \ + HANDLE_DIM(T, 3) \ + HANDLE_DIM(T, 4) \ + HANDLE_DIM(T, 5) + + HANDLE_TYPE(DT_BOOL); + HANDLE_TYPE(DT_FLOAT); + HANDLE_TYPE(DT_DOUBLE); + HANDLE_TYPE(DT_UINT8); + HANDLE_TYPE(DT_INT32); + HANDLE_TYPE(DT_INT16); + HANDLE_TYPE(DT_INT64); + HANDLE_TYPE(DT_STRING); // when DEVICE=CPUDevice. + +#undef HANDLE_TYPE +#undef HANDLE_DIM + + OP_REQUIRES(context, false, + errors::Unimplemented( + "TileOp : Unhandled input dimensions, DT : ", + context->input(0).dtype(), ", dims : ", input_dims)); + } + + private: + template + void HandleCaseImpl(OpKernelContext* context, + const gtl::ArraySlice& multiples_array, + Tensor* result) { + typedef typename EnumToDataType
::Type T; + Eigen::array broadcast_array; + for (int i = 0; i < NDIM; ++i) { + broadcast_array[i] = multiples_array[i]; + } + functor::Tile()( + context->eigen_device(), result->tensor(), + context->input(0).tensor(), broadcast_array); + } + + template + void HandleCase(OpKernelContext* context, + const gtl::ArraySlice& multiples_array, + Tensor* result); + + TF_DISALLOW_COPY_AND_ASSIGN(TileOp); +}; + +template +template +inline void TileOp::HandleCase( + OpKernelContext* context, const gtl::ArraySlice& multiples_array, + Tensor* result) { + LOG(FATAL) << "TileOp: Invalid combination of Device, DT and NDIM: " + << typeid(Device).name() << ", " << DataTypeString(DT) << ", " + << NDIM; +} + +#define HANDLE_CASE(device, dtype, ndim) \ + template <> \ + template <> \ + void TileOp::HandleCase( \ + OpKernelContext * context, \ + const gtl::ArraySlice& multiples_array, Tensor* result) { \ + HandleCaseImpl(context, multiples_array, result); \ + } + +#define HANDLE_CASE_DIM_POSITIVE(device, dtype) \ + HANDLE_CASE(device, dtype, 1); \ + HANDLE_CASE(device, dtype, 2); \ + HANDLE_CASE(device, dtype, 3); \ + HANDLE_CASE(device, dtype, 4); \ + HANDLE_CASE(device, dtype, 5); + +#define HANDLE_CASE_DIM(device, dtype) \ + HANDLE_CASE(device, dtype, 0); \ + HANDLE_CASE_DIM_POSITIVE(device, dtype); + +HANDLE_CASE_DIM(CPUDevice, DT_BOOL); +HANDLE_CASE_DIM(CPUDevice, DT_FLOAT); +HANDLE_CASE_DIM(CPUDevice, DT_DOUBLE); +HANDLE_CASE_DIM(CPUDevice, DT_UINT8); +HANDLE_CASE_DIM(CPUDevice, DT_INT32); +HANDLE_CASE_DIM(CPUDevice, DT_INT16); +HANDLE_CASE_DIM(CPUDevice, DT_INT64); +HANDLE_CASE_DIM(CPUDevice, DT_STRING); + +#if GOOGLE_CUDA +// Eigen on GPU does not handle 0-dimension data types yet. +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_FLOAT); +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_DOUBLE); +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT16); +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT32); +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT64); +#endif // GOOGLE_CUDA + +#undef HANDLE_CASE_DIM_POSITIVE +#undef HANDLE_CASE_DIM +#undef HANDLE_CASE + +// -------------------------------------------------------------------------- +template +class TileGradientOp : public OpKernel { + public: + explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& multiples = context->input(1); + OP_REQUIRES( + context, TensorShapeUtils::IsLegacyVector(multiples.shape()), + errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", + multiples.shape().ShortDebugString())); + OP_REQUIRES(context, input.dims() == multiples.NumElements(), + errors::InvalidArgument( + "Expected multiples argument to be a vector of length ", + input.dims(), " but got length ", multiples.dim_size(0))); + + const int input_dims = input.dims(); + const gtl::ArraySlice multiples_array(multiples.flat().data(), + input_dims); + + TensorShape output_shape; + std::vector input_dim_size_vec; + for (int i = 0; i < input_dims; ++i) { + OP_REQUIRES( + context, multiples_array[i] > 0, + errors::InvalidArgument("Expected multiples[", i, "] > 0, but got ", + multiples_array[i])); + OP_REQUIRES(context, input.dim_size(i) % multiples_array[i] == 0, + errors::InvalidArgument("Expected input_dim[", i, + "] to be divisible by multiples[", i, + "], but ", input.dim_size(i), " % ", + multiples_array[i], " != 0")); + output_shape.AddDim(input.dim_size(i) / multiples_array[i]); + input_dim_size_vec.push_back(input.dim_size(i)); + } + Tensor* result = 
nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result)); + +#define HANDLE_DIM(DT, NDIM) \ + if (context->input(0).dtype() == DT && input_dims == NDIM) { \ + HandleCase(context, input_dim_size_vec, multiples_array, \ + result); \ + return; \ + } + +#define HANDLE_TYPE(T) \ + HANDLE_DIM(T, 0) \ + HANDLE_DIM(T, 1) \ + HANDLE_DIM(T, 2) \ + HANDLE_DIM(T, 3) \ + HANDLE_DIM(T, 4) \ + HANDLE_DIM(T, 5) + + HANDLE_TYPE(DT_FLOAT); + HANDLE_TYPE(DT_DOUBLE); + HANDLE_TYPE(DT_INT32); + HANDLE_TYPE(DT_INT16); + HANDLE_TYPE(DT_INT64); + +#undef HANDLE_TYPE +#undef HANDLE_DIM + + OP_REQUIRES(context, false, + errors::Unimplemented( + "TileGradientOp : Unhandled input dimensions, DT : ", + context->input(0).dtype(), ", dims : ", input_dims)); + } + + private: + template + void HandleCase(OpKernelContext* context, + const std::vector& input_dims, + const gtl::ArraySlice& multiples_array, + Tensor* result); + + template + void HandleCaseImpl(OpKernelContext* context, + const std::vector& input_dims, + const gtl::ArraySlice& multiples_array, + Tensor* result) { + typedef typename EnumToDataType
::Type T; + + bool reduction_only = true; + std::vector reduction_dims; + + for (int i = 0; i < NDIM; ++i) { + if (input_dims[i] > multiples_array[i] && multiples_array[i] > 1) { + reduction_only = false; + break; + } else { + if (multiples_array[i] == input_dims[i]) { + reduction_dims.push_back(i); + } + } + } + + if (reduction_only) { +#define HANDLE_DIM(D) \ + if (reduction_dims.size() == (D)) { \ + HandleReduce(context, reduction_dims, result); \ + return; \ + } + // NOTE(keveman): Handling the most common case here. + // Adding more cases here would require more templating and code + // explosion. For instance, HANDLE_DIM(2) wouldn't make sense for NDIM=1. + HANDLE_DIM(NDIM > 0 ? 1 : 0); + +// Fall through to the unoptimized version. +#undef HANDLE_DIM + } + + Eigen::DSizes indices; + Eigen::DSizes sizes; + + // Accumulate slices along the dimensions into the output. The number of + // slices along dimension 'i' is simply the multiple along dimension 'i' + // passed to the original Tile op. + for (int i = 0; i < NDIM; ++i) { + sizes[i] = input_dims[i] / multiples_array[i]; + indices[i] = 0; + } + + bool first = true; + while (true) { + functor::TileGrad()( + context->eigen_device(), result->tensor(), + context->input(0).tensor(), indices, sizes, first); + first = false; + // Increment the begin indices. + int i = 0; + while (i < NDIM && indices[i] / sizes[i] == multiples_array[i] - 1) { + indices[i] = 0; + ++i; + } + // We are finished if we have iterated to the maximum along all + // dimensions. + if (i == NDIM) { + break; + } + indices[i] += sizes[i]; + } + } + + template + void HandleReduce(OpKernelContext* context, + const std::vector& reduce_dim_in, Tensor* result) { + static_assert(NDIM >= REDUCENDIM, "Too many reduced dimensions"); + Eigen::DSizes reduce_dim; + Eigen::DSizes reshape_dim; + + for (int i = 0; i < REDUCENDIM; ++i) { + reduce_dim[i] = reduce_dim_in[i]; + } + + for (int i = 0; i < NDIM; ++i) { + reshape_dim[i] = result->dim_size(i); + } + + functor::ReduceAndReshape()( + context->eigen_device(), result->tensor(), + context->input(0).tensor(), reduce_dim, reshape_dim); + } + + TF_DISALLOW_COPY_AND_ASSIGN(TileGradientOp); +}; + +template +template +inline void TileGradientOp::HandleCase( + OpKernelContext* context, const std::vector& input_dims, + const gtl::ArraySlice& multiples_array, Tensor* result) { + LOG(FATAL) << "TileGradientOp: Invalid combination of Device, DT and NDIM: " + << typeid(Device).name() << ", " << DataTypeString(DT) << ", " + << NDIM; +} + +#define HANDLE_CASE(device, dtype, ndim) \ + template <> \ + template <> \ + void TileGradientOp::HandleCase( \ + OpKernelContext * context, const std::vector& input_dims, \ + const gtl::ArraySlice& multiples_array, Tensor* result) { \ + HandleCaseImpl(context, input_dims, multiples_array, result); \ + } + +#define HANDLE_CASE_DIM_POSITIVE(device, dtype) \ + HANDLE_CASE(device, dtype, 1); \ + HANDLE_CASE(device, dtype, 2); \ + HANDLE_CASE(device, dtype, 3); \ + HANDLE_CASE(device, dtype, 4); \ + HANDLE_CASE(device, dtype, 5); + +#define HANDLE_CASE_DIM(device, dtype) \ + HANDLE_CASE(device, dtype, 0); \ + HANDLE_CASE_DIM_POSITIVE(device, dtype); + +HANDLE_CASE_DIM(CPUDevice, DT_FLOAT); +HANDLE_CASE_DIM(CPUDevice, DT_DOUBLE); +HANDLE_CASE_DIM(CPUDevice, DT_INT16); +HANDLE_CASE_DIM(CPUDevice, DT_INT32); +HANDLE_CASE_DIM(CPUDevice, DT_INT64); + +#if GOOGLE_CUDA +// Eigen on GPU does not handle 0-dimension data types yet. 
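// Standalone 1-D sketch of the gradient accumulation above (illustrative,
// standard library only): tiling repeats the input `multiple` times, so the
// gradient pass sums the corresponding slices of the incoming gradient back
// into the smaller output, as TileGradientOp::HandleCaseImpl does with its
// indices/sizes loop.
#include <cstddef>
#include <vector>

namespace tile_grad_sketch {

// in_grad has size out_size * multiple (the gradient w.r.t. the tiled
// tensor); the result has size out_size and accumulates every slice.
std::vector<float> TileGrad1D(const std::vector<float>& in_grad,
                              int out_size, int multiple) {
  std::vector<float> result(out_size, 0.0f);
  for (int m = 0; m < multiple; ++m) {
    for (int i = 0; i < out_size; ++i) {
      result[i] += in_grad[static_cast<std::size_t>(m) * out_size + i];
    }
  }
  return result;
}

}  // namespace tile_grad_sketch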
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_FLOAT); +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_DOUBLE); +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT16); +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT32); +HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT64); +#endif // GOOGLE_CUDA + +#undef HANDLE_CASE_DIM_POSITIVE +#undef HANDLE_CASE_DIM +#undef HANDLE_CASE + +REGISTER_KERNEL_BUILDER(Name("Tile").Device(DEVICE_CPU).HostMemory("multiples"), + TileOp); +REGISTER_KERNEL_BUILDER(Name("TileGrad") + .Device(DEVICE_CPU) + .HostMemory("multiples"), + TileGradientOp); + +#if GOOGLE_CUDA +#define DEFINE_GPU_TYPE(T) \ + DEFINE_GPU_DIM(T, 1) \ + DEFINE_GPU_DIM(T, 2) \ + DEFINE_GPU_DIM(T, 3) \ + DEFINE_GPU_DIM(T, 4) \ + DEFINE_GPU_DIM(T, 5) + +#define DEFINE_GPU_DIM(T, NDIM) \ + template <> \ + void Tile::operator()( \ + const GPUDevice& d, typename TTypes::Tensor out, \ + typename TTypes::ConstTensor in, \ + const Eigen::array& broadcast_array) const; \ + extern template struct Tile; \ + template <> \ + void TileGrad::operator()( \ + const GPUDevice& d, typename TTypes::Tensor out, \ + typename TTypes::ConstTensor in, \ + const Eigen::DSizes& indices, \ + const Eigen::DSizes& sizes, bool first) const; \ + extern template struct TileGrad; \ + template <> \ + void ReduceAndReshape::operator()( \ + const GPUDevice& d, typename TTypes::Tensor out, \ + typename TTypes::ConstTensor in, \ + const Eigen::DSizes& reduce_dim, \ + const Eigen::DSizes& reshape_dim) const; \ + extern template struct ReduceAndReshape; + +namespace functor { +DEFINE_GPU_TYPE(float); +DEFINE_GPU_TYPE(double); +DEFINE_GPU_TYPE(int64); +DEFINE_GPU_TYPE(int32); +DEFINE_GPU_TYPE(int16); +} // end namespace functor + +#undef DEFINE_GPU_DIM +#undef DEFINE_GPU_TYPE + +REGISTER_KERNEL_BUILDER(Name("Tile") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("multiples"), + TileOp); +REGISTER_KERNEL_BUILDER(Name("Tile") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("multiples"), + TileOp); +REGISTER_KERNEL_BUILDER(Name("Tile") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("multiples"), + TileOp); + +REGISTER_KERNEL_BUILDER(Name("TileGrad") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("multiples"), + TileGradientOp); +REGISTER_KERNEL_BUILDER(Name("TileGrad") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("multiples"), + TileGradientOp); +REGISTER_KERNEL_BUILDER(Name("TileGrad") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("multiples"), + TileGradientOp); +#endif // GOOGLE_CUDA +} // namespace tensorflow diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h new file mode 100644 index 0000000000..b3cc6165e0 --- /dev/null +++ b/tensorflow/core/kernels/tile_ops.h @@ -0,0 +1,48 @@ +#ifndef TENSORFLOW_KERNELS_TILE_OPS_H_ +#define TENSORFLOW_KERNELS_TILE_OPS_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +template +struct Tile { + void operator()(const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const Eigen::array& broadcast_array) const { + out.device(d) = in.broadcast(broadcast_array); + } +}; + +template +struct TileGrad { + void operator()(const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const Eigen::DSizes& indices, + const Eigen::DSizes& sizes, + bool first) const { + if (first) { + out.device(d) = in.slice(indices, sizes); + } 
else { + out.device(d) += in.slice(indices, sizes); + } + } +}; + +template +struct ReduceAndReshape { + void operator()(const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const Eigen::DSizes& reduce_dim, + const Eigen::DSizes& reshape_dim) const { + out.device(d) = in.sum(reduce_dim).reshape(reshape_dim); + } +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_TILE_OPS_H_ diff --git a/tensorflow/core/kernels/tile_ops_gpu.cu.cc b/tensorflow/core/kernels/tile_ops_gpu.cu.cc new file mode 100644 index 0000000000..29481e1a54 --- /dev/null +++ b/tensorflow/core/kernels/tile_ops_gpu.cu.cc @@ -0,0 +1,38 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/tile_ops.h" +#include + +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_TYPE(T) \ + DEFINE_DIM(T, 1) \ + DEFINE_DIM(T, 2) \ + DEFINE_DIM(T, 3) \ + DEFINE_DIM(T, 4) \ + DEFINE_DIM(T, 5) + +#define DEFINE_DIM(T, NDIM) \ + template struct Tile; \ + template struct TileGrad; \ + template struct ReduceAndReshape; + +DEFINE_TYPE(float) +DEFINE_TYPE(double) +DEFINE_TYPE(int64) +DEFINE_TYPE(int32) +DEFINE_TYPE(int16) +// NOTE(keveman): Eigen's int8 and string versions don't compile yet with nvcc. + +#undef DEFINE_DIM +#undef DEFINE_TYPE + +} // end namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc new file mode 100644 index 0000000000..79b5d4d07e --- /dev/null +++ b/tensorflow/core/kernels/topk_op.cc @@ -0,0 +1,71 @@ +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/gtl/top_n.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +template +class TopK : public OpKernel { + public: + explicit TopK(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); + } + + void Compute(OpKernelContext* context) override { + const auto& input_in = context->input(0); + OP_REQUIRES(context, input_in.dims() == 2, + errors::InvalidArgument("input must be 2-dimensional")); + OP_REQUIRES(context, input_in.dim_size(1) >= k_, + errors::InvalidArgument("input must have at least k columns")); + + const auto& input = input_in.matrix(); + + const auto num_rows = input_in.dim_size(0); // generally batch_size + const auto num_cols = input_in.dim_size(1); + + Tensor* values_out = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({num_rows, k_}), &values_out)); + Tensor* indices_out = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 1, TensorShape({num_rows, k_}), &indices_out)); + auto values = values_out->matrix(); + auto indices = indices_out->matrix(); + + gtl::TopN> filter(k_); + + for (int r = 0; r < num_rows; r++) { + for (int32 c = 0; c < num_cols; ++c) { + // The second element is the negated index, so that lower-index elements + // are considered larger than higher-index elements in case of ties. 
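// Standalone sketch of the tie-breaking trick described above (illustrative,
// standard library only): pairs compare by value first and then by the
// negated column index, so among equal values the lower index ranks higher,
// matching the (value, -index) pairs pushed into gtl::TopN below.
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

namespace topk_sketch {

// Returns the indices of the k largest entries of `row`, ties broken in
// favor of the lower index. Assumes k <= row.size().
std::vector<int> TopKIndices(const std::vector<float>& row, int k) {
  std::vector<std::pair<float, int>> entries;
  entries.reserve(row.size());
  for (int c = 0; c < static_cast<int>(row.size()); ++c) {
    entries.emplace_back(row[c], -c);  // negated index breaks ties
  }
  std::partial_sort(entries.begin(), entries.begin() + k, entries.end(),
                    std::greater<std::pair<float, int>>());
  std::vector<int> indices(k);
  for (int i = 0; i < k; ++i) indices[i] = -entries[i].second;
  return indices;
}

}  // namespace topk_sketch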
+ filter.push(std::make_pair(input(r, c), -c)); + } + + std::unique_ptr>> top_k(filter.Extract()); + for (int32 i = 0; i < k_; ++i) { + values(r, i) = (*top_k)[i].first; + indices(r, i) = -(*top_k)[i].second; + } + filter.Reset(); + } + } + + private: + int k_; +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("TopK").Device(DEVICE_CPU).TypeConstraint("T"), TopK) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc new file mode 100644 index 0000000000..611fa4ac41 --- /dev/null +++ b/tensorflow/core/kernels/training_ops.cc @@ -0,0 +1,884 @@ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/training_ops.h" + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +static inline bool DoInline(int64 size) { return size <= (256ll << 10); } + +template +struct ApplyGradientDescent { + void operator()(const CPUDevice& d, typename TTypes::Flat var, + typename TTypes::ConstScalar lr, + typename TTypes::ConstFlat grad) { + if (DoInline(var.size())) { + var -= grad * lr(); + } else { + var.device(d) -= grad * lr(); + } + } +}; + +template +struct ApplyAdagrad { + void operator()(const CPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstFlat grad) { + if (DoInline(var.size())) { + accum += grad.square(); + var -= grad * lr() * accum.rsqrt(); + } else { + accum.device(d) += grad.square(); + var.device(d) -= grad * lr() * accum.rsqrt(); + } + } +}; + +template +struct ApplyMomentum { + void operator()(const CPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar momentum) { + if (DoInline(var.size())) { + accum = accum * momentum() + grad; + var -= accum * lr(); + } else { + accum.device(d) = accum * momentum() + grad; + var.device(d) -= accum * lr(); + } + } +}; + +template +struct ApplyAdam { + void operator()(const CPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat m, typename TTypes::Flat v, + typename TTypes::ConstScalar beta1_power, + typename TTypes::ConstScalar beta2_power, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar beta1, + typename TTypes::ConstScalar beta2, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad) { + const T alpha = lr() * std::sqrt(1 - beta2_power()) / (1 - beta1_power()); + if (DoInline(var.size())) { + m += (grad - m) * (1 - beta1()); + v += (grad.square() - v) * (1 - beta2()); + var -= (m * alpha) / (v.sqrt() + epsilon()); + } else { + m.device(d) += (grad - m) * (1 - beta1()); + v.device(d) += (grad.square() - v) * (1 - beta2()); + var.device(d) -= (m * alpha) / (v.sqrt() + epsilon()); + } + } +}; + +template +struct ApplyRMSProp { + void operator()(const CPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat ms, typename TTypes::Flat mom, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar rho, + typename TTypes::ConstScalar momentum, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad) { + if (DoInline(var.size())) { + ms += (grad.square() - ms) * (1 - rho()); + mom = mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt()); + var -= mom; + } else { + ms.device(d) += 
(grad.square() - ms) * (1 - rho()); + mom.device(d) = + mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt()); + var.device(d) -= mom; + } + } +}; + +} // namespace functor + +template +class ApplyGradientDescentOp : public OpKernel { + public: + explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override { + if (use_exclusive_lock_) { + mutex_lock l(*ctx->input_ref_mutex(0)); + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } else { + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; + + void DoValidate(OpKernelContext* ctx) { + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + const Tensor& alpha = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(alpha.shape()), + errors::InvalidArgument("alpha is not a scalar: ", + alpha.shape().DebugString())); + const Tensor& delta = ctx->input(2); + OP_REQUIRES( + ctx, var.shape().IsSameSize(delta.shape()), + errors::InvalidArgument("var and delta do not have the same shape", + var.shape().DebugString(), " ", + delta.shape().DebugString())); + } + + void DoCompute(OpKernelContext* ctx) { + const Device& device = ctx->template eigen_device(); + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + const Tensor& alpha = ctx->input(1); + const Tensor& delta = ctx->input(2); + functor::ApplyGradientDescent()( + device, var.flat(), alpha.scalar(), delta.flat()); + } +}; + +#define REGISTER_KERNELS(D, T) \ + REGISTER_KERNEL_BUILDER( \ + Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint("T"), \ + ApplyGradientDescentOp); + +REGISTER_KERNELS(CPU, float); +REGISTER_KERNELS(CPU, double); + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyGradientDescent::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::ConstScalar alpha, \ + typename TTypes::ConstFlat delta); \ + extern template struct ApplyGradientDescent; +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNELS(GPU, float); +REGISTER_KERNELS(GPU, double); +#endif +#undef REGISTER_KERNELS + +template +class ApplyAdagradOp : public OpKernel { + public: + explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override { + if (use_exclusive_lock_) { + mutex_lock l1(*ctx->input_ref_mutex(0)); + // Don't try to acquire a lock on the second ref as they share the same + // mutex. 
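// Standalone sketch of the Adagrad update this kernel applies (illustrative,
// plain C++, same arithmetic as functor::ApplyAdagrad at the top of this
// file): accum += grad^2; var -= lr * grad / sqrt(accum). Assumes the accum
// entries are positive.
#include <cmath>
#include <cstddef>
#include <vector>

namespace adagrad_sketch {

// One dense Adagrad step over flat parameter/accumulator/gradient arrays.
void AdagradStep(std::vector<float>* var, std::vector<float>* accum,
                 float lr, const std::vector<float>& grad) {
  for (std::size_t i = 0; i < var->size(); ++i) {
    (*accum)[i] += grad[i] * grad[i];
    (*var)[i] -= lr * grad[i] / std::sqrt((*accum)[i]);
  }
}

}  // namespace adagrad_sketch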
+ // + // mutex_lock l2(*ctx->input_ref_mutex(1)); + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } else { + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; + + void DoValidate(OpKernelContext* ctx) { + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, accum.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + const Tensor& lr = ctx->input(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar: ", + lr.shape().DebugString())); + const Tensor& grad = ctx->input(3); + OP_REQUIRES( + ctx, var.shape().IsSameSize(accum.shape()), + errors::InvalidArgument("var and accum do not have the same shape", + var.shape().DebugString(), " ", + accum.shape().DebugString())); + OP_REQUIRES( + ctx, var.shape().IsSameSize(grad.shape()), + errors::InvalidArgument("var and delta do not have the same shape", + var.shape().DebugString(), " ", + grad.shape().DebugString())); + } + + void DoCompute(OpKernelContext* ctx) { + const Device& device = ctx->template eigen_device(); + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); + const Tensor& lr = ctx->input(2); + const Tensor& grad = ctx->input(3); + functor::ApplyAdagrad()(device, var.flat(), accum.flat(), + lr.scalar(), grad.flat()); + } +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#define REGISTER_KERNELS(D, T) \ + REGISTER_KERNEL_BUILDER( \ + Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint("T"), \ + ApplyAdagradOp); + +REGISTER_KERNELS(CPU, float); +REGISTER_KERNELS(CPU, double); + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyAdagrad::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::Flat accum, typename TTypes::ConstScalar lr, \ + typename TTypes::ConstFlat grad); \ + extern template struct ApplyAdagrad; +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNELS(GPU, float); +REGISTER_KERNELS(GPU, double); +#endif +#undef REGISTER_KERNELS + +// Note, this op works on cpu only. +template +class SparseApplyAdagradOp : public OpKernel { + public: + explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { + mutex* mu_var = ctx->input_ref_mutex(0); + // mu_accum is actually the same mutex as mu_var since currently we use a + // global mutex. 
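// Standalone sketch of the sparse variant performed by SparseApplyAdagradOp
// below (illustrative, plain C++): only the rows named by `indices` are
// touched, with grad row i applied to var/accum row indices[i]. Callers are
// expected to have validated that every index is in range.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

namespace sparse_adagrad_sketch {

// var and accum are (num_rows x row_size) row-major matrices; grad is
// (indices.size() x row_size). Rows outside `indices` are left untouched.
void SparseAdagradStep(std::vector<float>* var, std::vector<float>* accum,
                       int row_size, float lr,
                       const std::vector<float>& grad,
                       const std::vector<int64_t>& indices) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const std::size_t row = static_cast<std::size_t>(indices[i]);
    for (int c = 0; c < row_size; ++c) {
      const float g = grad[i * row_size + c];
      float& a = (*accum)[row * row_size + c];
      float& v = (*var)[row * row_size + c];
      a += g * g;
      v -= lr * g / std::sqrt(a);
    }
  }
}

}  // namespace sparse_adagrad_sketch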
+ // + // mutex* mu_accum = ctx->input_ref_mutex(1); + if (use_exclusive_lock_) { + mu_var->lock(); + } + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, accum.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + OP_REQUIRES( + ctx, var.shape().IsSameSize(accum.shape()), + errors::InvalidArgument("var and accum do not have the same shape", + var.shape().DebugString(), " ", + accum.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), + errors::InvalidArgument("var must be at least 1 dimensional")); + + const Tensor& lr = ctx->input(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar: ", + lr.shape().DebugString())); + const Tensor& grad = ctx->input(3); + const Tensor& indices = ctx->input(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument("indices must be one-dimensional")); + + for (int d = 1; d < var.dims(); d++) { + OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), + errors::InvalidArgument(strings::StrCat( + "var and grad must match in dimension ", d))); + } + const Tindex N = indices.dim_size(0); + OP_REQUIRES( + ctx, grad.dim_size(0) == N, + errors::InvalidArgument( + "grad must be the same size as indices in the first dimension.")); + + if (N > 0) { + const Tindex first_dim_size = var.dim_size(0); + // Validate all the indices are in range + auto indices_vec = indices.vec(); + for (Tindex i = 0; i < N; i++) { + const Tindex index = indices_vec(i); + OP_REQUIRES(ctx, index >= 0 && index < first_dim_size, + errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in indices is out of range"))); + } + + auto var_flat = var.flat_outer_dims(); + auto accum_flat = accum.flat_outer_dims(); + auto grad_flat = grad.flat_outer_dims(); + T lr_scalar = lr.scalar()(); + + // Note(yonghui): It might be worth multi-threading square() and rsqrt(). + for (Tindex i = 0; i < N; i++) { + const Tindex index = indices_vec(i); + auto a = accum_flat.template chip<0>(index); + auto g = grad_flat.template chip<0>(i); + auto v = var_flat.template chip<0>(index); + a += g.square(); + v -= g.constant(lr_scalar) * g * a.rsqrt(); + } + } + if (use_exclusive_lock_) { + mu_var->unlock(); + } + + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; +}; + +#define REGISTER_KERNELS(T, Tindices) \ + REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SparseApplyAdagradOp); + +REGISTER_KERNELS(float, int32); +REGISTER_KERNELS(float, int64); +REGISTER_KERNELS(double, int32); +REGISTER_KERNELS(double, int64); +#undef REGISTER_KERNELS + +template +class ApplyMomentumOp : public OpKernel { + public: + explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override { + if (use_exclusive_lock_) { + mutex_lock l1(*ctx->input_ref_mutex(0)); + // Don't try to acquire a lock on the second ref as they share the same + // mutex. 
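// Standalone sketch of the momentum update this kernel applies (illustrative,
// plain C++, same arithmetic as functor::ApplyMomentum at the top of this
// file): accum = accum * momentum + grad; var -= lr * accum.
#include <cstddef>
#include <vector>

namespace momentum_sketch {

// One dense momentum step over flat arrays.
void MomentumStep(std::vector<float>* var, std::vector<float>* accum,
                  float lr, const std::vector<float>& grad, float momentum) {
  for (std::size_t i = 0; i < var->size(); ++i) {
    (*accum)[i] = (*accum)[i] * momentum + grad[i];
    (*var)[i] -= lr * (*accum)[i];
  }
}

}  // namespace momentum_sketch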
+ // + // mutex_lock l2(*ctx->input_ref_mutex(1)); + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } else { + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; + + void DoValidate(OpKernelContext* ctx) { + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, accum.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + const Tensor& lr = ctx->input(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar: ", + lr.shape().DebugString())); + const Tensor& grad = ctx->input(3); + OP_REQUIRES( + ctx, var.shape().IsSameSize(accum.shape()), + errors::InvalidArgument("var and accum do not have the same shape", + var.shape().DebugString(), " ", + accum.shape().DebugString())); + OP_REQUIRES( + ctx, var.shape().IsSameSize(grad.shape()), + errors::InvalidArgument("var and delta do not have the same shape", + var.shape().DebugString(), " ", + grad.shape().DebugString())); + + const Tensor& momentum = ctx->input(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), + errors::InvalidArgument("momentum is not a scalar: ", + momentum.shape().DebugString())); + } + + void DoCompute(OpKernelContext* ctx) { + const Device& device = ctx->template eigen_device(); + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); + const Tensor& lr = ctx->input(2); + const Tensor& grad = ctx->input(3); + const Tensor& momentum = ctx->input(4); + functor::ApplyMomentum()(device, var.flat(), accum.flat(), + lr.scalar(), grad.flat(), + momentum.scalar()); + } +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#define REGISTER_KERNELS(D, T) \ + REGISTER_KERNEL_BUILDER( \ + Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint("T"), \ + ApplyMomentumOp); + +REGISTER_KERNELS(CPU, float); +REGISTER_KERNELS(CPU, double); + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyMomentum::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::Flat accum, typename TTypes::ConstScalar lr, \ + typename TTypes::ConstFlat grad, \ + typename TTypes::ConstScalar momentum); \ + extern template struct ApplyMomentum; +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNELS(GPU, float); +REGISTER_KERNELS(GPU, double); +#endif +#undef REGISTER_KERNELS + +// Note, this op works on cpu only. +template +class SparseApplyMomentumOp : public OpKernel { + public: + explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { + mutex* mu_var = ctx->input_ref_mutex(0); + // mu_accum is actually the same mutex as mu_var since currently we use a + // global mutex. 
+ // + // mutex* mu_accum = ctx->input_ref_mutex(1); + if (use_exclusive_lock_) { + mu_var->lock(); + } + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, accum.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + OP_REQUIRES( + ctx, var.shape().IsSameSize(accum.shape()), + errors::InvalidArgument("var and accum do not have the same shape", + var.shape().DebugString(), " ", + accum.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), + errors::InvalidArgument("var must be at least 1 dimensional")); + + const Tensor& lr = ctx->input(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar: ", + lr.shape().DebugString())); + const Tensor& grad = ctx->input(3); + const Tensor& indices = ctx->input(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument("indices must be one-dimensional")); + + for (int d = 1; d < var.dims(); d++) { + OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), + errors::InvalidArgument(strings::StrCat( + "var and grad must match in dimension ", d))); + } + const Tindex N = indices.dim_size(0); + OP_REQUIRES( + ctx, grad.dim_size(0) == N, + errors::InvalidArgument( + "grad must be the same size as indices in the first dimension.")); + + const Tensor& momentum = ctx->input(5); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), + errors::InvalidArgument("momentum is not a scalar: ", + momentum.shape().DebugString())); + + if (N > 0) { + const Tindex first_dim_size = var.dim_size(0); + // Validate all the indices are in range + auto indices_vec = indices.vec(); + for (Tindex i = 0; i < N; i++) { + const Tindex index = indices_vec(i); + OP_REQUIRES(ctx, index >= 0 && index < first_dim_size, + errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in indices is out of range"))); + } + + auto var_flat = var.flat_outer_dims(); + auto accum_flat = accum.flat_outer_dims(); + auto grad_flat = grad.flat_outer_dims(); + T lr_scalar = lr.scalar()(); + T momentum_scalar = momentum.scalar()(); + + for (Tindex i = 0; i < N; i++) { + const Tindex index = indices_vec(i); + auto a = accum_flat.template chip<0>(index); + auto g = grad_flat.template chip<0>(i); + auto v = var_flat.template chip<0>(index); + a = a * a.constant(momentum_scalar) + g; + v -= a.constant(lr_scalar) * a; + } + } + if (use_exclusive_lock_) { + mu_var->unlock(); + } + + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; +}; + +#define REGISTER_KERNELS(T, Tindices) \ + REGISTER_KERNEL_BUILDER(Name("SparseApplyMomentum") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SparseApplyMomentumOp); + +REGISTER_KERNELS(float, int32); +REGISTER_KERNELS(float, int64); +REGISTER_KERNELS(double, int32); +REGISTER_KERNELS(double, int64); +#undef REGISTER_KERNELS + +template +class ApplyAdamOp : public OpKernel { + public: + explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override { + if (use_exclusive_lock_) { + // all input refs share the same mutex + mutex_lock 
l1(*ctx->input_ref_mutex(0)); + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } else { + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; + + void DoValidate(OpKernelContext* ctx) { + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor m = ctx->mutable_input(1, use_exclusive_lock_); + Tensor v = ctx->mutable_input(2, use_exclusive_lock_); + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, m.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + OP_REQUIRES( + ctx, v.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(2))); + + const Tensor& beta1_power = ctx->input(3); + const Tensor& beta2_power = ctx->input(4); + const Tensor& lr = ctx->input(5); + const Tensor& beta1 = ctx->input(6); + const Tensor& beta2 = ctx->input(7); + const Tensor& epsilon = ctx->input(8); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()), + errors::InvalidArgument("beta1_power is not a scalar: ", + beta1_power.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()), + errors::InvalidArgument("beta2_power is not a scalar: ", + beta2_power.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar: ", + lr.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), + errors::InvalidArgument("beta1 is not a scalar: ", + beta1.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), + errors::InvalidArgument("beta2 is not a scalar: ", + beta2.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon.shape().DebugString())); + + const Tensor& grad = ctx->input(9); + OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), + errors::InvalidArgument("var and m do not have the same shape", + var.shape().DebugString(), " ", + m.shape().DebugString())); + OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), + errors::InvalidArgument("var and v do not have the same shape", + var.shape().DebugString(), " ", + v.shape().DebugString())); + OP_REQUIRES( + ctx, var.shape().IsSameSize(grad.shape()), + errors::InvalidArgument("var and grad do not have the same shape", + var.shape().DebugString(), " ", + grad.shape().DebugString())); + } + + void DoCompute(OpKernelContext* ctx) { + const Device& device = ctx->template eigen_device(); + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor m = ctx->mutable_input(1, use_exclusive_lock_); + Tensor v = ctx->mutable_input(2, use_exclusive_lock_); + const Tensor& beta1_power = ctx->input(3); + const Tensor& beta2_power = ctx->input(4); + const Tensor& lr = ctx->input(5); + const Tensor& beta1 = ctx->input(6); + const Tensor& beta2 = ctx->input(7); + const Tensor& epsilon = ctx->input(8); + const Tensor& grad = ctx->input(9); + + functor::ApplyAdam()(device, var.flat(), m.flat(), + v.flat(), beta1_power.scalar(), + beta2_power.scalar(), lr.scalar(), + beta1.scalar(), beta2.scalar(), + epsilon.scalar(), grad.flat()); + } +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#define 
REGISTER_KERNELS(D, T) \ + REGISTER_KERNEL_BUILDER( \ + Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint("T"), \ + ApplyAdamOp); + +REGISTER_KERNELS(CPU, float); +REGISTER_KERNELS(CPU, double); + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyAdam::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::Flat m, typename TTypes::Flat v, \ + typename TTypes::ConstScalar beta1_power, \ + typename TTypes::ConstScalar beta2_power, \ + typename TTypes::ConstScalar lr, \ + typename TTypes::ConstScalar beta1, \ + typename TTypes::ConstScalar beta2, \ + typename TTypes::ConstScalar epsilon, \ + typename TTypes::ConstFlat grad); \ + extern template struct ApplyAdam; +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNELS(GPU, float); +REGISTER_KERNELS(GPU, double); +#endif +#undef REGISTER_KERNELS + +template +class ApplyRMSPropOp : public OpKernel { + public: + explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override { + if (use_exclusive_lock_) { + // all input refs share the same mutex + mutex_lock l1(*ctx->input_ref_mutex(0)); + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } else { + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); + } + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; + + void DoValidate(OpKernelContext* ctx) { + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor ms = ctx->mutable_input(1, use_exclusive_lock_); + Tensor mom = ctx->mutable_input(2, use_exclusive_lock_); + + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, ms.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + OP_REQUIRES( + ctx, mom.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(2))); + + const Tensor& lr = ctx->input(3); + const Tensor& rho = ctx->input(4); + const Tensor& momentum = ctx->input(5); + const Tensor& epsilon = ctx->input(6); + const Tensor& grad = ctx->input(7); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar: ", + lr.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), + errors::InvalidArgument("rho is not a scalar: ", + rho.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), + errors::InvalidArgument("momentum is not a scalar: ", + momentum.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon.shape().DebugString())); + + OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), + errors::InvalidArgument("var and ms do not have the same shape", + var.shape().DebugString(), " ", + ms.shape().DebugString())); + + OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), + errors::InvalidArgument( + "var and mom do not have the same shape", + var.shape().DebugString(), " ", mom.shape().DebugString())); + + OP_REQUIRES( + ctx, var.shape().IsSameSize(grad.shape()), + errors::InvalidArgument("var and grad do 
not have the same shape", + var.shape().DebugString(), " ", + grad.shape().DebugString())); + } + + void DoCompute(OpKernelContext* ctx) { + const Device& device = ctx->template eigen_device(); + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor ms = ctx->mutable_input(1, use_exclusive_lock_); + Tensor mom = ctx->mutable_input(2, use_exclusive_lock_); + const Tensor& lr = ctx->input(3); + const Tensor& rho = ctx->input(4); + const Tensor& momentum = ctx->input(5); + const Tensor& epsilon = ctx->input(6); + const Tensor& grad = ctx->input(7); + + functor::ApplyRMSProp()(device, var.flat(), ms.flat(), + mom.flat(), lr.scalar(), + rho.scalar(), momentum.scalar(), + epsilon.scalar(), grad.flat()); + } +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#define REGISTER_KERNELS(D, T) \ + REGISTER_KERNEL_BUILDER( \ + Name("ApplyRMSProp").Device(DEVICE_##D).TypeConstraint("T"), \ + ApplyRMSPropOp); + +REGISTER_KERNELS(CPU, float); +REGISTER_KERNELS(CPU, double); + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyRMSProp::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::Flat ms, typename TTypes::Flat mom, \ + typename TTypes::ConstScalar lr, typename TTypes::ConstScalar rho, \ + typename TTypes::ConstScalar momentum, \ + typename TTypes::ConstScalar epsilon, \ + typename TTypes::ConstFlat grad); \ + extern template struct ApplyRMSProp; +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNELS(GPU, float); +REGISTER_KERNELS(GPU, double); +#endif +#undef REGISTER_KERNELS + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h new file mode 100644 index 0000000000..71f6d0253d --- /dev/null +++ b/tensorflow/core/kernels/training_ops.h @@ -0,0 +1,65 @@ +#ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_ +#define TENSORFLOW_KERNELS_TRAINING_OPS_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Each training algorithm has a ApplyXYZ functor struct declared in +// this header file. They are specialized for different devices +// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc). 
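// Illustrative sketch, not part of the original change: a device
// specialization of one of these functors typically just evaluates an Eigen
// expression on that device. For example, a CPU specialization of
// ApplyGradientDescent (the real CPU kernels live in training_ops.cc) could
// look roughly like this, using the parameter names declared below:
//
//   template <typename T>
//   struct ApplyGradientDescent<Eigen::ThreadPoolDevice, T> {
//     void operator()(const Eigen::ThreadPoolDevice& d,
//                     typename TTypes<T>::Flat var,
//                     typename TTypes<T>::ConstScalar alpha,
//                     typename TTypes<T>::ConstFlat delta) {
//       // var <- var - alpha * delta, element-wise on the CPU device.
//       var.device(d) -= delta * alpha();
//     }
//   };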
+
+template <typename Device, typename T>
+struct ApplyGradientDescent {
+  void operator()(const Device& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::ConstScalar alpha,
+                  typename TTypes<T>::ConstFlat delta);
+};
+
+template <typename Device, typename T>
+struct ApplyAdagrad {
+  void operator()(const Device& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstFlat grad);
+};
+
+template <typename Device, typename T>
+struct ApplyMomentum {
+  void operator()(const Device& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstFlat grad,
+                  typename TTypes<T>::ConstScalar momentum);
+};
+
+template <typename Device, typename T>
+struct ApplyAdam {
+  void operator()(const Device& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+                  typename TTypes<T>::ConstScalar beta1_power,
+                  typename TTypes<T>::ConstScalar beta2_power,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar beta1,
+                  typename TTypes<T>::ConstScalar beta2,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad);
+};
+
+template <typename Device, typename T>
+struct ApplyRMSProp {
+  void operator()(const Device& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar rho,
+                  typename TTypes<T>::ConstScalar momentum,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad);
+};
+
+}  // end namespace functor
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_TRAINING_OPS_H_
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
new file mode 100644
index 0000000000..3106f29648
--- /dev/null
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -0,0 +1,127 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/training_ops.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T>
+struct ApplyGradientDescent<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::ConstScalar alpha,
+                  typename TTypes<T>::ConstFlat delta) {
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = delta.dimension(0);
+    Eigen::Sizes<1> single;
+    var.device(d) -= alpha.reshape(single).broadcast(bcast) * delta;
+  }
+};
+
+template <typename T>
+struct ApplyAdagrad<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstFlat grad) {
+    accum.device(d) += grad.square();
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = grad.dimension(0);
+    Eigen::Sizes<1> single;
+    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
+  }
+};
+
+template <typename T>
+struct ApplyMomentum<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstFlat grad,
+                  typename TTypes<T>::ConstScalar momentum) {
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = grad.dimension(0);
+    Eigen::Sizes<1> single;
+    accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
+    var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+  }
+};
+
+template <typename T>
+struct ApplyAdam<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+                  typename TTypes<T>::ConstScalar beta1_power,
+                  typename TTypes<T>::ConstScalar beta2_power,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename
TTypes::ConstScalar beta1, + typename TTypes::ConstScalar beta2, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad) { + Eigen::array::Tensor::Index, 1> bcast; + bcast[0] = grad.dimension(0); + Eigen::Sizes<1> single; + const auto one = static_cast(1.0); + m.device(d) = + m + + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) * + (grad - m); + v.device(d) = + v + + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) * + (grad.square() - v); + var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() / + (beta1_power.constant(one) - beta1_power)) + .reshape(single) + .broadcast(bcast) * + m / (epsilon.reshape(single).broadcast(bcast) + v.sqrt()); + } +}; + +template +struct ApplyRMSProp { + void operator()(const GPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat ms, typename TTypes::Flat mom, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar rho, + typename TTypes::ConstScalar momentum, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad) { + Eigen::array::Tensor::Index, 1> bcast; + bcast[0] = grad.dimension(0); + Eigen::Sizes<1> single; + const auto one = static_cast(1.0); + ms.device(d) = ms + + (rho.constant(one) - rho).reshape(single).broadcast(bcast) * + (grad.square() - ms); + mom.device(d) = + mom * momentum.reshape(single).broadcast(bcast) + + lr.reshape(single).broadcast(bcast) * grad / + ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt()); + var.device(d) -= mom; + } +}; + +} // namespace functor + +template struct functor::ApplyGradientDescent; +template struct functor::ApplyGradientDescent; + +template struct functor::ApplyAdagrad; +template struct functor::ApplyAdagrad; + +template struct functor::ApplyMomentum; +template struct functor::ApplyMomentum; + +template struct functor::ApplyAdam; +template struct functor::ApplyAdam; + +template struct functor::ApplyRMSProp; +template struct functor::ApplyRMSProp; +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/training_ops_test.cc b/tensorflow/core/kernels/training_ops_test.cc new file mode 100644 index 0000000000..3c629badb6 --- /dev/null +++ b/tensorflow/core/kernels/training_ops_test.cc @@ -0,0 +1,226 @@ +#include +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +// We focus on the single thread performance of training ops. 
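// Illustrative sketch, not part of the original change: a plain scalar
// reference for the per-element update that ApplyAdam (benchmarked below as
// BM_Adam) performs, matching the Eigen expressions in training_ops_gpu.cu.cc
// above. The function name and signature are hypothetical; it exists only to
// make the benchmarked math explicit.
#include <cmath>  // std::sqrt, used only by the sketch below

static void AdamReferenceUpdate(float* var, float* m, float* v,
                                const float* grad, int n, float lr,
                                float beta1, float beta2, float beta1_power,
                                float beta2_power, float epsilon) {
  // Effective step size, folding in the bias-correction terms.
  const float alpha =
      lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  for (int i = 0; i < n; ++i) {
    m[i] += (1.0f - beta1) * (grad[i] - m[i]);             // 1st moment
    v[i] += (1.0f - beta2) * (grad[i] * grad[i] - v[i]);   // 2nd moment
    var[i] -= alpha * m[i] / (epsilon + std::sqrt(v[i]));  // parameter step
  }
}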
+static SessionOptions InitSingleThreadedOptions() { + SessionOptions opts; + opts.config.set_intra_op_parallelism_threads(1); + opts.config.set_inter_op_parallelism_threads(1); + return opts; +} + +static SessionOptions* GetOptions() { + static SessionOptions opts = InitSingleThreadedOptions(); + return &opts; +} + +static Node* Var(Graph* g, int n) { + return test::graph::Var(g, DT_FLOAT, TensorShape({n})); +} + +static Node* Zeros(Graph* g, int n) { + Tensor data(DT_FLOAT, TensorShape({n})); + data.flat().setZero(); + return test::graph::Constant(g, data); +} + +static Node* Random(Graph* g, int n) { + Tensor data(DT_FLOAT, TensorShape({n})); + data.flat().setRandom(); + return test::graph::Constant(g, data); +} + +static Node* Scalar(Graph* g, float val) { + Tensor data(DT_FLOAT, TensorShape({})); + data.flat()(0) = val; + return test::graph::Constant(g, data); +} + +static void SGD(int32 n, Graph** init_g, Graph** train_g) { + RequireDefaultOps(); + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + test::graph::Assign(g, var, Zeros(g, n)); + *init_g = g; + } + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto lr = Scalar(g, 0.01); + auto grad = Random(g, n); + test::graph::Multi(g, "ApplyGradientDescent", {var, lr, grad}); + *train_g = g; + } +} + +static void BM_SGD(int iters, int params) { + const int64 tot = static_cast(iters) * params; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * sizeof(float)); + Graph* init; + Graph* train; + SGD(params, &init, &train); + test::Benchmark("cpu", train, GetOptions(), init).Run(iters); +} +BENCHMARK(BM_SGD)->Arg(128 << 10)->Arg(256 << 10); + +static void Adagrad(int32 n, Graph** init_g, Graph** train_g) { + RequireDefaultOps(); + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto accum = Var(g, n); + auto zero = Zeros(g, n); + test::graph::Assign(g, var, zero); + test::graph::Assign(g, accum, zero); + *init_g = g; + } + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto accum = Var(g, n); + auto lr = Scalar(g, 0.01); + auto grad = Random(g, n); + test::graph::Multi(g, "ApplyAdagrad", {var, accum, lr, grad}); + *train_g = g; + } +} + +static void BM_Adagrad(int iters, int params) { + const int64 tot = static_cast(iters) * params; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * sizeof(float)); + Graph* init; + Graph* train; + Adagrad(params, &init, &train); + test::Benchmark("cpu", train, GetOptions(), init).Run(iters); +} +BENCHMARK(BM_Adagrad)->Arg(128 << 10)->Arg(256 << 10); + +static void Momentum(int32 n, Graph** init_g, Graph** train_g) { + RequireDefaultOps(); + TensorShape shape({n}); + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto accum = Var(g, n); + auto zero = Zeros(g, n); + test::graph::Assign(g, var, zero); + test::graph::Assign(g, accum, zero); + *init_g = g; + } + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto accum = Var(g, n); + auto lr = Scalar(g, 0.01); + auto grad = Random(g, n); + auto mom = Scalar(g, 0.01); + test::graph::Multi(g, "ApplyMomentum", {var, accum, lr, grad, mom}); + *train_g = g; + } +} + +static void BM_Momentum(int iters, int params) { + const int64 tot = static_cast(iters) * params; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * sizeof(float)); + Graph* init; + Graph* train; + Momentum(params, &init, &train); + test::Benchmark("cpu", train, GetOptions(), init).Run(iters); +} 
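// Annotation, not part of the original change: in the BENCHMARK(...)->Arg(n)
// registrations in this file, n is the number of parameters being updated,
// so Arg(128 << 10) exercises a vector of 131072 floats and Arg(256 << 10)
// one of 262144 floats. ItemsProcessed/BytesProcessed in the BM_* functions
// are scaled by the same n, so the benchmark framework can report throughput
// in elements and bytes updated.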
+BENCHMARK(BM_Momentum)->Arg(128 << 10)->Arg(256 << 10); + +static void Adam(int32 n, Graph** init_g, Graph** train_g) { + RequireDefaultOps(); + TensorShape shape({n}); + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto m = Var(g, n); + auto v = Var(g, n); + auto zero = Zeros(g, n); + test::graph::Assign(g, var, zero); + test::graph::Assign(g, m, zero); + test::graph::Assign(g, v, zero); + *init_g = g; + } + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto m = Var(g, n); + auto v = Var(g, n); + auto beta1_power = Scalar(g, 0.9); + auto beta2_power = Scalar(g, 0.99); + auto lr = Scalar(g, 0.01); + auto beta1 = Scalar(g, 0.9); + auto beta2 = Scalar(g, 0.99); + auto epsilon = Scalar(g, 1e-8); + auto grad = Random(g, n); + test::graph::Multi(g, "ApplyAdam", {var, m, v, beta1_power, beta2_power, lr, + beta1, beta2, epsilon, grad}); + *train_g = g; + } +} + +static void BM_Adam(int iters, int params) { + const int64 tot = static_cast(iters) * params; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * sizeof(float)); + Graph* init; + Graph* train; + Adam(params, &init, &train); + test::Benchmark("cpu", train, GetOptions(), init).Run(iters); +} +BENCHMARK(BM_Adam)->Arg(128 << 10)->Arg(256 << 10); + +static void RMSProp(int32 n, Graph** init_g, Graph** train_g) { + RequireDefaultOps(); + TensorShape shape({n}); + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto ms = Var(g, n); + auto mom = Var(g, n); + auto zero = Zeros(g, n); + test::graph::Assign(g, var, zero); + test::graph::Assign(g, ms, zero); + test::graph::Assign(g, mom, zero); + *init_g = g; + } + { + Graph* g = new Graph(OpRegistry::Global()); + auto var = Var(g, n); + auto ms = Var(g, n); + auto mom = Var(g, n); + auto lr = Scalar(g, 0.01); + auto rho = Scalar(g, 0.9); + auto momentum = Scalar(g, 0.9); + auto epsilon = Scalar(g, 1e-8); + auto grad = Random(g, n); + test::graph::Multi(g, "ApplyRMSProp", + {var, ms, mom, lr, rho, momentum, epsilon, grad}); + *train_g = g; + } +} + +static void BM_RMSProp(int iters, int params) { + const int64 tot = static_cast(iters) * params; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * sizeof(float)); + Graph* init; + Graph* train; + RMSProp(params, &init, &train); + test::Benchmark("cpu", train, GetOptions(), init).Run(iters); +} +BENCHMARK(BM_RMSProp)->Arg(128 << 10)->Arg(256 << 10); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc new file mode 100644 index 0000000000..4f11a881f8 --- /dev/null +++ b/tensorflow/core/kernels/transpose_op.cc @@ -0,0 +1,190 @@ +// See docs in ../ops/array_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/transpose_op.h" +#include "tensorflow/core/kernels/transpose_op_functor.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// inv = InvertPermutationOp(T p) takes a permutation of +// integers 0, 1, ..., n - 1 and returns the inverted +// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n). +// +// REQUIRES: input is a vector of int32. +// REQUIRES: input is a permutation of 0, 1, ..., n-1. 
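// Annotation, not part of the original change: a worked example of the
// contract above. For
//   p   = [2, 0, 1]
// the op produces
//   inv = [1, 2, 0]
// since p[1] == 0, p[2] == 1 and p[0] == 2, i.e. inv[p[i]] == i for every i.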
+ +class InvertPermutationOp : public OpKernel { + public: + explicit InvertPermutationOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(input.shape()), + errors::InvalidArgument("invert_permutation expects a 1D vector.")); + auto Tin = input.vec(); + const int N = Tin.size(); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + auto Tout = output->vec(); + std::fill_n(Tout.data(), N, -1); + for (int i = 0; i < N; ++i) { + const int32 d = Tin(i); + OP_REQUIRES(context, 0 <= d && d < N, + errors::InvalidArgument(d, " is not between 0 and ", N)); + OP_REQUIRES(context, Tout(d) == -1, + errors::InvalidArgument(d, " is duplicated in the input.")); + Tout(d) = i; + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("InvertPermutation").Device(DEVICE_CPU), + InvertPermutationOp); + +// output = TransposeOp(T input, T perm) takes a tensor +// of type T and rank N, and a permutation of 0, 1, ..., N-1. It +// shuffles the dimensions of the input tensor according to permutation. +// +// Specifically, the returned tensor output meets the following condition: +// 1) output.dims() == input.dims(); +// 2) output.dim_size(i) == input.dim_size(perm[i]); +// 3) output.tensor(i_0, i_1, ..., i_N-1) == +// input.tensor(j_0, j_1, ..., j_N-1), +// where i_s == j_{perm[s]} +// +// REQUIRES: perm is a vector of int32. +// REQUIRES: input.dims() == perm.size(). +// REQUIRES: perm is a permutation. + +template +TransposeOp::TransposeOp(OpKernelConstruction* context) + : OpKernel(context) {} + +template +void TransposeOp::Compute(OpKernelContext* context) { + const Tensor& input = context->input(0); + const Tensor& perm = context->input(1); + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsVector(perm.shape()), + errors::InvalidArgument("perm must be a vector, not ", + perm.shape().DebugString())); + auto Vperm = perm.vec(); + const int dims = input.dims(); + static const int kMinDims = 1; + static const int kMaxDims = 8; + OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims, + errors::Unimplemented("Transposing a tensor of rank ", dims, + " is not implemented.")); + OP_REQUIRES(context, dims == Vperm.size(), + errors::InvalidArgument( + "transpose expects a vector of size ", input.dims(), + ". But input(1) is a vector of size ", Vperm.size())); + gtl::ArraySlice permutation( + reinterpret_cast(Vperm.data()), dims); + TensorShape shape; + + // Check whether permutation is a permutation of integers of [0 .. dims). + gtl::InlinedVector bits(dims); + for (const int32 d : permutation) { + OP_REQUIRES( + context, 0 <= d && d < dims, + errors::InvalidArgument(d, " is out of range [0 .. 
", dims, ")")); + bits[d] = true; + shape.AddDim(input.dim_size(d)); + } + for (int i = 0; i < dims; ++i) { + OP_REQUIRES(context, bits[i], errors::InvalidArgument( + i, " is missing from {", + str_util::Join(permutation, ","), "}.")); + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output)); + switch (dims) { +#define EXPAND_DIM(N) \ + case N: { \ + functor::TransposeFunctor func; \ + func(context->eigen_device(), output->tensor(), \ + input.tensor(), permutation.data()); \ + break; \ + } + EXPAND_DIM(1); + EXPAND_DIM(2); + EXPAND_DIM(3); + EXPAND_DIM(4); + EXPAND_DIM(5); + EXPAND_DIM(6); + EXPAND_DIM(7); + EXPAND_DIM(8); + default: + LOG(FATAL) << "Unexpected dims: " << dims; + } +#undef EXPAND_CASE +} + +namespace functor { + +template +void TransposeMaybeInline(const Device& d, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const int* perm) { + // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU. + Eigen::array p; + for (int i = 0; i < NDIMS; ++i) p[i] = perm[i]; + if (out.size() * sizeof(T) < 131072) { // Small transpose on a CPU: do inline + out = in.shuffle(p); + } else { + out.device(d) = in.shuffle(p); + } +} + +template +struct TransposeFunctor { + void operator()(const CPUDevice& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, const int* perm) { + TransposeMaybeInline(d, out, in, perm); + } +}; + +} // namespace functor + +#define REGISTER(D, T) \ + template class TransposeOp; \ + REGISTER_KERNEL_BUILDER(Name("Transpose") \ + .Device(DEVICE_##D) \ + .TypeConstraint("T") \ + .HostMemory("perm"), \ + TransposeOp) +REGISTER(CPU, float); +REGISTER(CPU, double); +REGISTER(CPU, complex64); +REGISTER(CPU, uint8); +REGISTER(CPU, int8); +REGISTER(CPU, int16); +REGISTER(CPU, int32); +REGISTER(CPU, int64); +REGISTER(CPU, string); +#if GOOGLE_CUDA +REGISTER(GPU, uint8); +REGISTER(GPU, int8); +REGISTER(GPU, int16); +REGISTER(GPU, int32); +REGISTER(GPU, int64); +REGISTER(GPU, float); +REGISTER(GPU, double); +#endif +#undef REGISTER +} // namespace tensorflow diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h new file mode 100644 index 0000000000..f7a5be5c2b --- /dev/null +++ b/tensorflow/core/kernels/transpose_op.h @@ -0,0 +1,19 @@ +#ifndef TENSORFLOW_KERNELS_TRANSPOSE_OP_H_ +#define TENSORFLOW_KERNELS_TRANSPOSE_OP_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +template +class TransposeOp : public OpKernel { + public: + explicit TransposeOp(OpKernelConstruction* context); + void Compute(OpKernelContext* context) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_ diff --git a/tensorflow/core/kernels/transpose_op_functor.h b/tensorflow/core/kernels/transpose_op_functor.h new file mode 100644 index 0000000000..8cbd1cbb29 --- /dev/null +++ b/tensorflow/core/kernels/transpose_op_functor.h @@ -0,0 +1,28 @@ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +void Transpose(const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, const int* perm) { + // perm[] is a permutation of 0, 1, 
..., NDIMS-1. perm[] is on CPU. + Eigen::array p; + for (int i = 0; i < NDIMS; ++i) p[i] = perm[i]; + out.device(d) = in.shuffle(p); +} + +template +struct TransposeFunctor { + void operator()(const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, const int* perm); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_ diff --git a/tensorflow/core/kernels/transpose_op_gpu.cu.cc b/tensorflow/core/kernels/transpose_op_gpu.cu.cc new file mode 100644 index 0000000000..8c04a6544e --- /dev/null +++ b/tensorflow/core/kernels/transpose_op_gpu.cu.cc @@ -0,0 +1,43 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/kernels/transpose_op_functor.h" + +namespace tensorflow { +namespace functor { + +template +struct TransposeFunctor { + void operator()(const Eigen::GpuDevice& d, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in, const int* perm) { + Transpose(d, out, in, perm); + } +}; + +#define DEFINE(T, N) template struct TransposeFunctor; +#define DEFINE_DIM(T) \ + DEFINE(T, 1); \ + DEFINE(T, 2); \ + DEFINE(T, 3); \ + DEFINE(T, 4); \ + DEFINE(T, 5); \ + DEFINE(T, 6); \ + DEFINE(T, 7); \ + DEFINE(T, 8); +DEFINE_DIM(uint8); +DEFINE_DIM(int8); +DEFINE_DIM(int16); +DEFINE_DIM(int32); +DEFINE_DIM(int64); +DEFINE_DIM(float); +DEFINE_DIM(double); +#undef DEFINE_DIM +#undef DEFINE + +} // end namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc new file mode 100644 index 0000000000..61f4a54583 --- /dev/null +++ b/tensorflow/core/kernels/unique_op.cc @@ -0,0 +1,61 @@ +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class UniqueOp : public OpKernel { + public: + explicit UniqueOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt, DT_INT32})); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), + errors::InvalidArgument("unique expects a 1D vector.")); + auto Tin = input.vec(); + const int N = Tin.size(); + + Tensor* idx = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, input.shape(), &idx)); + auto idx_vec = idx->template vec(); + + std::unordered_map uniq; + uniq.reserve(2 * N); + for (int i = 0, j = 0; i < N; ++i) { + auto it = uniq.insert(std::make_pair(Tin(i), j)); + idx_vec(i) = it.first->second; + if (it.second) { + ++j; + } + } + int32 uniq_size = uniq.size(); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({uniq_size}), &output)); + auto output_vec = output->template vec(); + + for (auto it : uniq) { + output_vec(it.second) = it.first; + } + } +}; + +#define REGISTER_UNIQUE(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Unique").Device(DEVICE_CPU).TypeConstraint("T"), \ + UniqueOp) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE); +#undef REGISTER_UNIQUE +} // namespace tensorflow diff --git a/tensorflow/core/kernels/unique_op_test.cc 
b/tensorflow/core/kernels/unique_op_test.cc new file mode 100644 index 0000000000..658f2282cf --- /dev/null +++ b/tensorflow/core/kernels/unique_op_test.cc @@ -0,0 +1,51 @@ +#include +#include +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { + +namespace { + +static void BM_Unique(int iters, int dim) { + testing::StopTiming(); + RequireDefaultOps(); + Graph* g = new Graph(OpRegistry::Global()); + + Tensor input(DT_INT32, TensorShape({dim})); + input.flat().setRandom(); + + Node* node; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Unique") + .Input(test::graph::Constant(g, input)) + .Attr("T", DT_INT32) + .Finalize(g, &node)); + + testing::BytesProcessed(static_cast(iters) * dim * sizeof(int32)); + testing::UseRealTime(); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} + +BENCHMARK(BM_Unique) + ->Arg(32) + ->Arg(256) + ->Arg(1024) + ->Arg(4 * 1024) + ->Arg(16 * 1024) + ->Arg(64 * 1024) + ->Arg(256 * 1024); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc new file mode 100644 index 0000000000..36cfb2c8e5 --- /dev/null +++ b/tensorflow/core/kernels/unpack_op.cc @@ -0,0 +1,96 @@ +// See docs in ../ops/array_ops.cc. + +#define EIGEN_USE_THREADS + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/split_op.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class UnpackOp : public OpKernel { + public: + explicit UnpackOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* context) override { + const int32 num = num_outputs(); + const Tensor& input = context->input(0); + const TensorShape& input_shape = input.shape(); + + OP_REQUIRES( + context, input_shape.dims() > 0 && input_shape.dim_size(0) == num, + errors::InvalidArgument("Input shape must start with ", num, ", got ", + input_shape.ShortDebugString())); + + auto output_shape = input_shape; + output_shape.RemoveDim(0); + const int32 output_size = output_shape.num_elements(); + + // Special case: Aligned, so we can share the underlying buffer. + // + // Apply this optimization conservatively: if input is aligned, + // the resulting tensors must be aligned. It's conservative + // because if the immediate consumer of the resulting tensors are + // not using eigen for computation, its perfectly fine to avoid + // the copying. 
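    // Annotation, not part of the original change: as a concrete example of
    // the fast path, unpacking a float tensor of shape [4, 32] along dim 0
    // yields four outputs of shape [32]; each slice begins 32 * sizeof(float)
    // = 128 bytes into the shared buffer, so every slice keeps the buffer's
    // alignment and can alias the input rather than be copied. Whether a
    // given inner-dims size qualifies is decided by IsInnerDimsSizeAligned.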
+ if (output_size == 0 || IsInnerDimsSizeAligned(input_shape)) { + for (int i = 0; i < num; ++i) { + Tensor output; + CHECK(output.CopyFrom(input.Slice(i, i + 1), output_shape)); + context->set_output(i, output); + } + return; + } + + // Except for shape, unpack is a special case of split, so we reuse the + // same computational kernels. + auto input_reshaped = input.shaped({1, num, output_size}); + + for (int i = 0; i < num; ++i) { + Tensor* output; + OP_REQUIRES_OK(context, + context->allocate_output(i, output_shape, &output)); + auto output_shaped = output->shaped({1, 1, output_size}); + + Eigen::DSizes indices{0, i, 0}; + Eigen::DSizes sizes{1, 1, output_size}; + functor::Split()(context->eigen_device(), + output_shaped, input_reshaped, indices, + sizes); + } + } +}; + +#define REGISTER_UNPACK(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Unpack").Device(DEVICE_CPU).TypeConstraint("T"), \ + UnpackOp) + +TF_CALL_ALL_TYPES(REGISTER_UNPACK); + +#undef REGISTER_UNPACK + +#if GOOGLE_CUDA + +#define REGISTER_GPU(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Unpack").Device(DEVICE_GPU).TypeConstraint("T"), \ + UnpackOp) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +#undef REGISTER_GPU + +#endif // GOOGLE_CUDA + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc new file mode 100644 index 0000000000..2f1dbc68c0 --- /dev/null +++ b/tensorflow/core/kernels/variable_ops.cc @@ -0,0 +1,37 @@ +#define EIGEN_USE_THREADS +#include "tensorflow/core/kernels/variable_ops.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp); +REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU), + TemporaryVariableOp); +REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU), + DestroyTemporaryVariableOp); + +#if GOOGLE_CUDA +// Only register 'Variable' on GPU for the subset of types also supported by +// 'Assign' (see dense_update_ops.cc.) 
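// Annotation, not part of the original change: for a single element type such
// as float, the REGISTER_GPU_KERNELS macro below expands to roughly
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Variable").Device(DEVICE_GPU).TypeConstraint<float>("dtype"),
//       VariableOp);
//
// plus the matching TemporaryVariable and DestroyTemporaryVariable entries,
// and TF_CALL_GPU_NUMBER_TYPES stamps it out once per supported type.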
+#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Variable").Device(DEVICE_GPU).TypeConstraint("dtype"), \ + VariableOp); \ + REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("dtype"), \ + TemporaryVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + DestroyTemporaryVariableOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h new file mode 100644 index 0000000000..77d2da0ad4 --- /dev/null +++ b/tensorflow/core/kernels/variable_ops.h @@ -0,0 +1,146 @@ +#ifndef TENSORFLOW_KERNELS_VARIABLE_OPS_H_ +#define TENSORFLOW_KERNELS_VARIABLE_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class VariableOp : public OpKernel { + public: + explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); + dtype_ = RemoveRefType(context->output_type(0)); + } + + ~VariableOp() override { + if (var_) var_->Unref(); + } + + void Compute(OpKernelContext* ctx) override { + mutex_lock l(init_mu_); + if (var_ == nullptr) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), + true /* use name() */)); + auto creator = [this](Var** var) { + *var = new Var(dtype_); + (*var)->tensor()->set_shape(shape_); + return Status::OK(); + }; + OP_REQUIRES_OK(ctx, + cinfo_.resource_manager()->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &var_, creator)); + } + // Output a reference to our tensor, so it may be updated. + // + // As long as *this is alive, the ref we return here is valid + // because *this owns a ref on var_. + ctx->set_output_ref(0, var_->mu(), var_->tensor()); + } + + private: + class Var : public ResourceBase { + public: + explicit Var(DataType dtype) : tensor_(dtype) {} + mutex* mu() { return &mu_; } + Tensor* tensor() { return &tensor_; } + + string DebugString() override { + return strings::StrCat(DataTypeString(tensor_.dtype()), "/", + tensor_.shape().ShortDebugString()); + } + + private: + mutex mu_; + Tensor tensor_; + + ~Var() override {} + TF_DISALLOW_COPY_AND_ASSIGN(Var); + }; + + DataType dtype_; + TensorShape shape_; + + mutex init_mu_; + ContainerInfo cinfo_ GUARDED_BY(init_mu_); + Var* var_ GUARDED_BY(init_mu_) = nullptr; + + TF_DISALLOW_COPY_AND_ASSIGN(VariableOp); +}; + +class TemporaryVariableOp : public OpKernel { + public: + explicit TemporaryVariableOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); + OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); + // Variable name defaults to op name if not specified explicitly. 
+ if (var_name_ == "") var_name_ = name(); + } + + void Compute(OpKernelContext* context) override { + Status s; + ResourceMgr* rm = context->step_resource_manager(); + OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); + auto* tmp_var = new TmpVar; + OP_REQUIRES(context, tmp_var, + errors::ResourceExhausted("Could not allocate TmpVar.")); + tmp_var->name = var_name_; + s = context->allocate_temp(dtype_, shape_, &tmp_var->val); + if (!s.ok()) tmp_var->Unref(); + OP_REQUIRES_OK(context, s); + OP_REQUIRES_OK(context, rm->Create("tmp_var", var_name_, tmp_var)); + context->set_output_ref(0, &tmp_var->mu, &tmp_var->val); + } + + private: + // Refcounted temporary variable resource. + friend class DestroyTemporaryVariableOp; + struct TmpVar : public ResourceBase { + mutex mu; + Tensor val; + string name; + string DebugString() override { return name; } + ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; } + }; + + TensorShape shape_; + DataType dtype_; + string var_name_; +}; + +class DestroyTemporaryVariableOp : public OpKernel { + public: + explicit DestroyTemporaryVariableOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES(context, IsRefType(context->input_type(0)), + errors::InvalidArgument("lhs input needs to be a ref type")) + OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); + OP_REQUIRES(context, var_name_ != "", + errors::InvalidArgument("Missing var_name attribute")); + } + + void Compute(OpKernelContext* context) override { + // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed + // their execution before this DestroyTemporaryVariable op executes. + // This is typically achieved using control dependencies. + CHECK(IsRefType(context->input_dtype(0))); + Tensor tmpvar = context->mutable_input(0, false); + context->set_output(0, tmpvar); + ResourceMgr* rm = context->step_resource_manager(); + OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); + OP_REQUIRES_OK( + context, rm->Delete("tmp_var", var_name_)); + } + + private: + string var_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_VARIABLE_OPS_H_ diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc new file mode 100644 index 0000000000..9db0943ea7 --- /dev/null +++ b/tensorflow/core/kernels/where_op.cc @@ -0,0 +1,74 @@ +// See docs in ../ops/array_ops.cc. 
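// Annotation, not part of the original change: Where emits the coordinates of
// every true element of a bool tensor, one row per true element. For example,
//   input  = [[true, false, true],
//             [false, true, false]]
// produces an int64 output of shape [3, 2]:
//   output = [[0, 0],
//             [0, 2],
//             [1, 1]]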
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/where_op.h" + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class WhereOp : public OpKernel { + public: + explicit WhereOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + + const int input_dims = input.dims(); + Tensor num_true; + OP_REQUIRES_OK( + context, context->allocate_temp(DT_INT64, TensorShape({}), &num_true)); + auto num_true_t = num_true.scalar(); + + functor::NumTrue::Compute(context->eigen_device(), + input.flat(), num_true_t); + TensorShape output_shape({num_true_t(), input_dims}); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + +#define HANDLE_DIM(NDIM) \ + case NDIM: \ + functor::Where::Compute(context->eigen_device(), \ + input.tensor(), \ + output->matrix()); \ + break; + + switch (input_dims) { + HANDLE_DIM(1); + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + + default: + OP_REQUIRES(context, false, + errors::InvalidArgument( + "WhereOp : Unhandled input dimensions: ", input_dims)); + } +#undef HANDLE_DIM + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(WhereOp); +}; + +#define REGISTER_WHERE() \ + REGISTER_KERNEL_BUILDER(Name("Where").Device(DEVICE_CPU), WhereOp); + +REGISTER_WHERE(); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h new file mode 100644 index 0000000000..c7b835d02f --- /dev/null +++ b/tensorflow/core/kernels/where_op.h @@ -0,0 +1,65 @@ +#ifndef TENSORFLOW_KERNELS_WHERE_OP_H_ +#define TENSORFLOW_KERNELS_WHERE_OP_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace functor { + +template +struct NumTrue { + EIGEN_ALWAYS_INLINE static void Compute( + const Device& d, typename TTypes::ConstFlat input, + TTypes::Scalar num_true) { + num_true.device(d) = input.template cast().sum(); + } +}; + +template +struct Where { + EIGEN_ALWAYS_INLINE static void Compute( + const Device& d, typename TTypes::ConstTensor input, + typename TTypes::Matrix output) { + Eigen::DenseIndex true_n = 0; + Eigen::DSizes dims = input.dimensions(); + Eigen::DSizes strides; + + // Calculate strides for RowMajor order. + EIGEN_STATIC_ASSERT((static_cast(decltype(input)::Layout) == + static_cast(Eigen::RowMajor)), + INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR); + + strides[NDIM - 1] = 1; + for (int i = NDIM - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * dims[i + 1]; + } + + // Note, no bounds checking is done on true_n. It is assumed that + // the output was correctly sized via output of NumTrue::Compute. 
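    // Annotation, not part of the original change: a worked example of the
    // row-major decomposition performed below. For dims = {2, 3} the strides
    // computed above are {3, 1}, so a flat index of 4 is written out as
    //   output(true_n, 0) = 4 / 3 = 1   (remainder 1)
    //   output(true_n, 1) = 1 / 1 = 1
    // i.e. the coordinates (1, 1), which is indeed element 4 of a 2x3
    // row-major tensor.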
+ for (Eigen::DenseIndex n = 0; n < input.size(); ++n) { + if (input.data()[n]) { + WriteIndexRowMajor(output, strides, true_n, n); + ++true_n; + } + } + } + + EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor( + typename TTypes::Matrix output, + const Eigen::DSizes& strides, + Eigen::DenseIndex true_n, Eigen::DenseIndex index) { + for (int i = 0; i < NDIM; ++i) { + output(true_n, i) = index / strides[i]; + index %= strides[i]; + } + } +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_WHERE_OP_H_ diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc new file mode 100644 index 0000000000..b940163ec9 --- /dev/null +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -0,0 +1,108 @@ +// See docs in ../ops/io_ops.cc. + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/kernels/reader_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +static Status ReadEntireFile(Env* env, const string& filename, + string* contents) { + uint64 file_size = 0; + TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size)); + contents->resize(file_size); + RandomAccessFile* file; + TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file)); + std::unique_ptr make_sure_file_gets_deleted(file); + StringPiece data; + TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(*contents)[0])); + if (data.size() != file_size) { + return errors::DataLoss("Truncated read of '", filename, "' expected ", + file_size, " got ", data.size()); + } + if (data.data() != &(*contents)[0]) { + memmove(&(*contents)[0], data.data(), data.size()); + } + return Status::OK(); +} + +class WholeFileReader : public ReaderBase { + public: + WholeFileReader(Env* env, const string& node_name) + : ReaderBase(strings::StrCat("WholeFileReader '", node_name, "'")), + env_(env) {} + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + *key = current_work(); + TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value)); + *produced = true; + *at_end = true; + return Status::OK(); + } + + // Stores state in a ReaderBaseState proto, since WholeFileReader has + // no additional state beyond ReaderBase. 
+ Status SerializeStateLocked(string* state) override { + ReaderBaseState base_state; + SaveBaseState(&base_state); + base_state.SerializeToString(state); + return Status::OK(); + } + + Status RestoreStateLocked(const string& state) override { + ReaderBaseState base_state; + if (!ParseProtoUnlimited(&base_state, state)) { + return errors::InvalidArgument("Could not parse state for ", name(), ": ", + str_util::CEscape(state)); + } + TF_RETURN_IF_ERROR(RestoreBaseState(base_state)); + return Status::OK(); + } + + private: + Env* env_; +}; + +class WholeFileReaderOp : public ReaderOpKernel { + public: + explicit WholeFileReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + Env* env = context->env(); + SetReaderFactory( + [this, env]() { return new WholeFileReader(env, name()); }); + } +}; + +REGISTER_KERNEL_BUILDER(Name("WholeFileReader").Device(DEVICE_CPU), + WholeFileReaderOp); + +class ReadFileOp : public OpKernel { + public: + using OpKernel::OpKernel; + void Compute(OpKernelContext* context) override { + const Tensor* input; + OP_REQUIRES_OK(context, context->input("filename", &input)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(input->shape()), + errors::InvalidArgument( + "Input filename tensor must be scalar, but had shape: ", + input->shape().DebugString())); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("contents", + TensorShape({}), &output)); + OP_REQUIRES_OK(context, + ReadEntireFile(context->env(), input->scalar()(), + &output->scalar()())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc new file mode 100644 index 0000000000..ff54d157af --- /dev/null +++ b/tensorflow/core/kernels/xent_op.cc @@ -0,0 +1,90 @@ +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/kernels/xent_op.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class SoftmaxXentWithLogitsOp : public OpKernel { + public: + explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& logits_in = context->input(0); + const Tensor& labels_in = context->input(1); + OP_REQUIRES(context, logits_in.IsSameSize(labels_in), + errors::InvalidArgument( + "logits and labels must be same size: logits_size=", + logits_in.shape().DebugString(), " labels_size=", + labels_in.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()), + errors::InvalidArgument("logits must be 2-dimensional")); + // As we already tested that both inputs have the same shape no need to + // check that "labels" is a matrix too. + + // loss is 1-D (one per example), and size is batch_size. 
+ + Tensor scratch; + OP_REQUIRES_OK( + context, context->allocate_temp(DataTypeToEnum::value, + TensorShape({logits_in.dim_size(0), 1}), + &scratch)); + + Tensor* loss_out = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({logits_in.dim_size(0)}), &loss_out)); + Tensor* back_out = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(1, logits_in.shape(), &back_out)); + + functor::XentFunctor functor; + functor(context->eigen_device(), logits_in.matrix(), + labels_in.matrix(), scratch.matrix(), loss_out->vec(), + back_out->matrix()); + } +}; + +// Partial specialization for a CPUDevice, that uses the Eigen implementation +// from XentEigenImpl. +namespace functor { +template +struct XentFunctor { + void operator()(const CPUDevice& d, typename TTypes::ConstMatrix logits, + typename TTypes::ConstMatrix labels, + typename TTypes::Matrix scratch, + typename TTypes::Vec loss, + typename TTypes::Matrix backprop) { + XentEigenImpl::Compute(d, logits, labels, scratch, loss, + backprop); + } +}; +} // namespace functor + +REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + SoftmaxXentWithLogitsOp); +REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + SoftmaxXentWithLogitsOp); + +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") + .Device(DEVICE_GPU) + .TypeConstraint("T"), + SoftmaxXentWithLogitsOp); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h new file mode 100644 index 0000000000..edb7d817c8 --- /dev/null +++ b/tensorflow/core/kernels/xent_op.h @@ -0,0 +1,102 @@ +#ifndef TENSORFLOW_KERNELS_XENT_OP_H_ +#define TENSORFLOW_KERNELS_XENT_OP_H_ +// Functor definition for XentOp, must be compilable by nvcc. + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by XentOp to do the computations. +template +struct XentFunctor { + // Computes Cross Entropy loss and backprop. + // + // logits: batch_size, num_classes. + // labels: batch_size, num_classes. + // scratch: temporary tensor, dims: batch_size, 1 + // loss: output tensor for the loss, dims: batch_size. + // backprop: output tensor for the backprop, dims: batch_size, num_classes. + void operator()(const Device& d, typename TTypes::ConstMatrix logits, + typename TTypes::ConstMatrix labels, + typename TTypes::Matrix scratch, + typename TTypes::Vec loss, + typename TTypes::Matrix backprop); +}; + +// Eigen code implementing XentFunctor::operator(). +// This code works for both CPU and GPU and is used by the functor +// specializations for both device types. +template +struct XentEigenImpl { + static void Compute(const Device& d, typename TTypes::ConstMatrix logits, + typename TTypes::ConstMatrix labels, + typename TTypes::Matrix scratch, + typename TTypes::Vec loss, + typename TTypes::Matrix backprop) { + // NOTE(mdevin): This duplicates some of the computations in softmax_op + // because we need the intermediate (logits -max(logits)) values to + // avoid a log(exp()) in the computation of the loss. 
+ + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + +// These arrays are used to reduce along the class dimension, and broadcast +// the resulting value to all classes. +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::array along_class; + along_class[0] = kClassDim; + Eigen::array batch_only; + batch_only[0] = batch_size; + Eigen::array batch_by_one; + batch_by_one[0] = batch_size; + batch_by_one[1] = 1; + Eigen::array one_by_class; + one_by_class[0] = 1; + one_by_class[1] = num_classes; +#else + Eigen::IndexList > along_class; + Eigen::IndexList > batch_by_one; + batch_by_one.set(0, batch_size); + Eigen::IndexList batch_only; + batch_only.set(0, batch_size); + Eigen::IndexList, int> one_by_class; + one_by_class.set(1, num_classes); +#endif + + // max_logits along classes. + scratch.reshape(batch_only).device(d) = logits.maximum(along_class); + + // logits - max_logits. + backprop.device(d) = logits - scratch.broadcast(one_by_class); + + // sum(exp(logits - max_logits)) along classes. + scratch.reshape(batch_only).device(d) = backprop.exp().sum(along_class); + + // NOTE(keveman): Eigen on GPU dispatches to an optimized implementaion + // for an expression of the form lhs = rhs.sum(). + // lhs = -rhs.sum() doesn't match the above pattern, so folding in the + // negation before calling sum(). + // sum(-labels * + // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) + // along classes + loss.device(d) = + (labels * (scratch.log().eval().broadcast(one_by_class) - backprop)) + .eval() + .sum(along_class); + + // backprop: prob - labels, where + // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) + backprop.device(d) = + (backprop.exp() / scratch.broadcast(one_by_class)) - labels; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_XENT_OP_H_ diff --git a/tensorflow/core/kernels/xent_op_gpu.cu.cc b/tensorflow/core/kernels/xent_op_gpu.cu.cc new file mode 100644 index 0000000000..eec6a84281 --- /dev/null +++ b/tensorflow/core/kernels/xent_op_gpu.cu.cc @@ -0,0 +1,35 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/xent_op.h" + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Partial specialization for a GPUDevice, that uses the Eigen implementation +// from XentEigenImpl. +namespace functor { +template +struct XentFunctor { + void operator()(const GPUDevice& d, typename TTypes::ConstMatrix logits, + typename TTypes::ConstMatrix labels, + typename TTypes::Matrix scratch, + typename TTypes::Vec loss, + typename TTypes::Matrix backprop) { + XentEigenImpl::Compute(d, logits, labels, scratch, loss, + backprop); + } +}; +} // end namespace functor + +// Instantiate the GPU implementation for float. 
+template struct functor::XentFunctor; + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/xent_op_test.cc b/tensorflow/core/kernels/xent_op_test.cc new file mode 100644 index 0000000000..9aab1b09bf --- /dev/null +++ b/tensorflow/core/kernels/xent_op_test.cc @@ -0,0 +1,46 @@ +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include +#include "tensorflow/core/kernels/xent_op.h" + +namespace tensorflow { + +static Graph* Xent(int batch_size, int num_classes) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes})); + logits.flat().setRandom(); + Tensor labels(DT_FLOAT, TensorShape({batch_size, num_classes})); + labels.flat().setRandom(); + test::graph::Binary(g, "SoftmaxCrossEntropyWithLogits", + test::graph::Constant(g, logits), + test::graph::Constant(g, labels)); + return g; +} + +#define BM_XentDev(BATCH, CLASS, DEVICE) \ + static void BM_Xent##_##BATCH##_##CLASS##_##DEVICE(int iters) { \ + testing::ItemsProcessed(static_cast(iters) * BATCH * CLASS); \ + test::Benchmark(#DEVICE, Xent(BATCH, CLASS)).Run(iters); \ + } \ + BENCHMARK(BM_Xent##_##BATCH##_##CLASS##_##DEVICE); + +/// The representative tests for ptb_word on GPU +BM_XentDev(16, 10000, gpu); +BM_XentDev(16, 30000, gpu); +BM_XentDev(16, 100000, gpu); + +BM_XentDev(32, 10000, gpu); +BM_XentDev(32, 30000, gpu); +BM_XentDev(32, 100000, gpu); + +BM_XentDev(64, 10000, gpu); +BM_XentDev(64, 30000, gpu); +BM_XentDev(64, 100000, gpu); + +/// Only the smaller tests for CPU. Otherwise, it's too slow +BM_XentDev(16, 10000, cpu); +BM_XentDev(32, 10000, cpu); +BM_XentDev(64, 10000, cpu); + +} // end namespace tensorflow diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc new file mode 100644 index 0000000000..ceb1001af0 --- /dev/null +++ b/tensorflow/core/lib/core/arena.cc @@ -0,0 +1,246 @@ +// This approach to arenas overcomes many of the limitations described +// in the "Specialized allocators" section of +// http://www.pdos.lcs.mit.edu/~dm/c++-new.html +// +// A somewhat similar approach to Gladiator, but for heap-detection, was +// suggested by Ron van der Wal and Scott Meyers at +// http://www.aristeia.com/BookErrata/M27Comments_frames.html + +#include "tensorflow/core/lib/core/arena.h" + +#include +#include + +#include + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { +namespace core { + +static const int kPageSize = getpagesize(); + +// ---------------------------------------------------------------------- +// Arena::Arena() +// Arena::~Arena() +// Destroying the arena automatically calls Reset() +// ---------------------------------------------------------------------- + +Arena::Arena(const size_t block_size) + : remaining_(0), + block_size_(block_size), + freestart_(NULL), // set for real in Reset() + blocks_alloced_(1), + overflow_blocks_(NULL) { + assert(block_size > kDefaultAlignment); + + first_blocks_[0].mem = reinterpret_cast(malloc(block_size_)); + + first_blocks_[0].size = block_size_; + + Reset(); +} + +Arena::~Arena() { + FreeBlocks(); + assert(overflow_blocks_ == NULL); // FreeBlocks() should do that + // The first X blocks stay allocated always by default. Delete them now. 
+ for (size_t i = 0; i < blocks_alloced_; ++i) free(first_blocks_[i].mem); +} + +// Returns true iff it advances freestart_ to the first position +// satisfying alignment without exhausting the current block. +bool Arena::SatisfyAlignment(size_t alignment) { + const size_t overage = reinterpret_cast(freestart_) & (alignment - 1); + if (overage > 0) { + const size_t waste = alignment - overage; + if (waste >= remaining_) { + return false; + } + freestart_ += waste; + remaining_ -= waste; + } + DCHECK_EQ(0, reinterpret_cast(freestart_) & (alignment - 1)); + return true; +} + +// ---------------------------------------------------------------------- +// Arena::Reset() +// Clears all the memory an arena is using. +// ---------------------------------------------------------------------- + +void Arena::Reset() { + FreeBlocks(); + freestart_ = first_blocks_[0].mem; + remaining_ = first_blocks_[0].size; + + // There is no guarantee the first block is properly aligned, so + // enforce that now. + CHECK(SatisfyAlignment(kDefaultAlignment)); + + freestart_when_empty_ = freestart_; +} + +// ---------------------------------------------------------------------- +// Arena::MakeNewBlock() +// Our sbrk() equivalent. We always make blocks of the same size +// (though GetMemory() can also make a new block for really big +// data. +// ---------------------------------------------------------------------- + +void Arena::MakeNewBlock(const uint32 alignment) { + AllocatedBlock* block = AllocNewBlock(block_size_, alignment); + freestart_ = block->mem; + remaining_ = block->size; + CHECK(SatisfyAlignment(alignment)); +} + +// The following simple numeric routines also exist in util/math/mathutil.h +// but we don't want to depend on that library. + +// Euclid's algorithm for Greatest Common Denominator. +static uint32 GCD(uint32 x, uint32 y) { + while (y != 0) { + uint32 r = x % y; + x = y; + y = r; + } + return x; +} + +static uint32 LeastCommonMultiple(uint32 a, uint32 b) { + if (a > b) { + return (a / GCD(a, b)) * b; + } else if (a < b) { + return (b / GCD(b, a)) * a; + } else { + return a; + } +} + +// ------------------------------------------------------------- +// Arena::AllocNewBlock() +// Adds and returns an AllocatedBlock. +// The returned AllocatedBlock* is valid until the next call +// to AllocNewBlock or Reset. (i.e. anything that might +// affect overflow_blocks_). +// ------------------------------------------------------------- + +Arena::AllocatedBlock* Arena::AllocNewBlock(const size_t block_size, + const uint32 alignment) { + AllocatedBlock* block; + // Find the next block. + if (blocks_alloced_ < TF_ARRAYSIZE(first_blocks_)) { + // Use one of the pre-allocated blocks + block = &first_blocks_[blocks_alloced_++]; + } else { // oops, out of space, move to the vector + if (overflow_blocks_ == NULL) + overflow_blocks_ = new std::vector; + // Adds another block to the vector. + overflow_blocks_->resize(overflow_blocks_->size() + 1); + // block points to the last block of the vector. + block = &overflow_blocks_->back(); + } + + // NOTE(tucker): this utility is made slightly more complex by + // not disallowing the case where alignment > block_size. + // Can we, without breaking existing code? + + // Must be a multiple of kDefaultAlignment, unless requested + // alignment is 1, in which case we don't care at all. + const uint32 adjusted_alignment = + (alignment > 1 ? 
LeastCommonMultiple(alignment, kDefaultAlignment) : 1); + + CHECK_LE(adjusted_alignment, 1 << 20) + << "Alignment on boundaries greater than 1MB not supported."; + + // If block_size > alignment we force block_size to be a multiple + // of alignment; if block_size < alignment we make no adjustment. + size_t adjusted_block_size = block_size; + if (adjusted_alignment > 1) { + if (adjusted_block_size > adjusted_alignment) { + const uint32 excess = adjusted_block_size % adjusted_alignment; + adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0); + } + block->mem = reinterpret_cast( + port::aligned_malloc(adjusted_block_size, adjusted_alignment)); + } else { + block->mem = reinterpret_cast(malloc(adjusted_block_size)); + } + block->size = adjusted_block_size; + CHECK(NULL != block->mem) << "block_size=" << block_size + << " adjusted_block_size=" << adjusted_block_size + << " alignment=" << alignment + << " adjusted_alignment=" << adjusted_alignment; + + return block; +} + +// ---------------------------------------------------------------------- +// Arena::GetMemoryFallback() +// We take memory out of our pool, aligned on the byte boundary +// requested. If we don't have space in our current pool, we +// allocate a new block (wasting the remaining space in the +// current block) and give you that. If your memory needs are +// too big for a single block, we make a special your-memory-only +// allocation -- this is equivalent to not using the arena at all. +// ---------------------------------------------------------------------- + +void* Arena::GetMemoryFallback(const size_t size, const int alignment) { + if (0 == size) { + return NULL; // stl/stl_alloc.h says this is okay + } + + // alignment must be a positive power of 2. + CHECK(alignment > 0 && 0 == (alignment & (alignment - 1))); + + // If the object is more than a quarter of the block size, allocate + // it separately to avoid wasting too much space in leftover bytes. + if (block_size_ == 0 || size > block_size_ / 4) { + return AllocNewBlock(size, alignment)->mem; + } + + // Enforce alignment on freestart_ then check for adequate space, + // which may require starting a new block. + if (!SatisfyAlignment(alignment) || size > remaining_) { + MakeNewBlock(alignment); + } + CHECK_LE(size, remaining_); + + remaining_ -= size; + void* result = freestart_; + freestart_ += size; + + return result; +} + +// ---------------------------------------------------------------------- +// Arena::ReturnMemoryFallback() +// Arena::FreeBlocks() +// Unlike GetMemory(), which does actual work, ReturnMemory() is a +// no-op: we don't "free" memory until Reset() is called. We do +// update some stats, though. Note we do no checking that the +// pointer you pass in was actually allocated by us, or that it +// was allocated for the size you say, so be careful here! +// FreeBlocks() does the work for Reset(), actually freeing all +// memory allocated in one fell swoop. 
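//
// A hypothetical usage sketch tying these calls together (block size and
// allocation sizes below are illustrative only):
//
//   Arena arena(1024);            // one 1KB starter block
//   char* a = arena.Alloc(100);   // fast path inside GetMemory()
//   char* b = arena.Alloc(4000);  // > block_size/4, gets a dedicated block
//   arena.Reset();                // keeps only the first block and realigns
//                                 // freestart_; everything else is freed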
+// ---------------------------------------------------------------------- + +void Arena::FreeBlocks() { + for (size_t i = 1; i < blocks_alloced_; ++i) { // keep first block alloced + free(first_blocks_[i].mem); + first_blocks_[i].mem = NULL; + first_blocks_[i].size = 0; + } + blocks_alloced_ = 1; + if (overflow_blocks_ != NULL) { + std::vector::iterator it; + for (it = overflow_blocks_->begin(); it != overflow_blocks_->end(); ++it) { + free(it->mem); + } + delete overflow_blocks_; // These should be used very rarely + overflow_blocks_ = NULL; + } +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/arena.h b/tensorflow/core/lib/core/arena.h new file mode 100644 index 0000000000..59896803bb --- /dev/null +++ b/tensorflow/core/lib/core/arena.h @@ -0,0 +1,90 @@ +// TODO(vrv): Switch this to an open-sourced version of Arena. + +#ifndef TENSORFLOW_LIB_CORE_ARENA_H_ +#define TENSORFLOW_LIB_CORE_ARENA_H_ + +#include + +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// This class is "thread-compatible": different threads can access the +// arena at the same time without locking, as long as they use only +// const methods. +class Arena { + public: + // Allocates a thread-compatible arena with the specified block size. + explicit Arena(const size_t block_size); + ~Arena(); + + char* Alloc(const size_t size) { + return reinterpret_cast(GetMemory(size, 1)); + } + + void Reset(); + +// This should be the worst-case alignment for any type. This is +// good for IA-32, SPARC version 7 (the last one I know), and +// supposedly Alpha. i386 would be more time-efficient with a +// default alignment of 8, but ::operator new() uses alignment of 4, +// and an assertion will fail below after the call to MakeNewBlock() +// if you try to use a larger alignment. +#ifdef __i386__ + static const int kDefaultAlignment = 4; +#else + static const int kDefaultAlignment = 8; +#endif + + protected: + bool SatisfyAlignment(const size_t alignment); + void MakeNewBlock(const uint32 alignment); + void* GetMemoryFallback(const size_t size, const int align); + void* GetMemory(const size_t size, const int align) { + assert(remaining_ <= block_size_); // an invariant + if (size > 0 && size < remaining_ && align == 1) { // common case + void* result = freestart_; + freestart_ += size; + remaining_ -= size; + return result; + } + return GetMemoryFallback(size, align); + } + + size_t remaining_; + + private: + struct AllocatedBlock { + char* mem; + size_t size; + }; + + // Allocate new new block of at least block_size, with the specified + // alignment. + // The returned AllocatedBlock* is valid until the next call to AllocNewBlock + // or Reset (i.e. anything that might affect overflow_blocks_). + AllocatedBlock* AllocNewBlock(const size_t block_size, + const uint32 alignment); + + const size_t block_size_; + char* freestart_; // beginning of the free space in most recent block + char* freestart_when_empty_; // beginning of the free space when we're empty + // STL vector isn't as efficient as it could be, so we use an array at first + size_t blocks_alloced_; // how many of the first_blocks_ have been alloced + AllocatedBlock first_blocks_[16]; // the length of this array is arbitrary + // if the first_blocks_ aren't enough, expand into overflow_blocks_. 
+ std::vector* overflow_blocks_; + + void FreeBlocks(); // Frees all except first block + + TF_DISALLOW_COPY_AND_ASSIGN(Arena); +}; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_ARENA_H_ diff --git a/tensorflow/core/lib/core/arena_test.cc b/tensorflow/core/lib/core/arena_test.cc new file mode 100644 index 0000000000..fa147c3014 --- /dev/null +++ b/tensorflow/core/lib/core/arena_test.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/core/arena.h" + +#include + +namespace tensorflow { +namespace core { +namespace { + +// Write random data to allocated memory +static void TestMemory(void* mem, int size) { + // Check that we can memset the entire memory + memset(mem, 0xaa, size); + + // Do some memory allocation to check that the arena doesn't mess up + // the internal memory allocator + char* tmp[100]; + for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) { + tmp[i] = new char[i * i + 1]; + } + + memset(mem, 0xcc, size); + + // Free up the allocated memory; + for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) { + delete[] tmp[i]; + } + + // Check that we can memset the entire memory + memset(mem, 0xee, size); +} + +TEST(ArenaTest, TestBasicArena) { + Arena a(1024); + char* memory = a.Alloc(100); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 100); + + // Allocate again + memory = a.Alloc(100); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 100); +} + +TEST(ArenaTest, TestVariousArenaSizes) { + { + Arena a(1024); + + // Allocate blocksize + char* memory = a.Alloc(1024); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 1024); + + // Allocate another blocksize + char* memory2 = a.Alloc(1024); + ASSERT_NE(memory2, nullptr); + TestMemory(memory2, 1024); + } + + // Allocate an arena and allocate two blocks + // that together exceed a block size + { + Arena a(1024); + + // + char* memory = a.Alloc(768); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 768); + + // Allocate another blocksize + char* memory2 = a.Alloc(768); + ASSERT_NE(memory2, nullptr); + TestMemory(memory2, 768); + } + + // Allocate larger than a blocksize + { + Arena a(1024); + + char* memory = a.Alloc(10240); + ASSERT_NE(memory, nullptr); + TestMemory(memory, 10240); + + // Allocate another blocksize + char* memory2 = a.Alloc(1234); + ASSERT_NE(memory2, nullptr); + TestMemory(memory2, 1234); + } +} + +} // namespace +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/bit_cast_test.cc b/tensorflow/core/lib/core/bit_cast_test.cc new file mode 100644 index 0000000000..0ea583e96f --- /dev/null +++ b/tensorflow/core/lib/core/bit_cast_test.cc @@ -0,0 +1,95 @@ +// Unit test for bit_cast template. + +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/platform/logging.h" +#include + +namespace tensorflow { + +// Marshall and unmarshall. +// ISO spec C++ section 3.9 promises this will work. + +template +struct marshall { + char buf[N]; +}; + +template +void TestMarshall(const T values[], int num_values) { + for (int i = 0; i < num_values; ++i) { + T t0 = values[i]; + marshall m0 = bit_cast >(t0); + T t1 = bit_cast(m0); + marshall m1 = bit_cast >(t1); + ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T))); + ASSERT_EQ(0, memcmp(&m0, &m1, sizeof(T))); + } +} + +// Convert back and forth to an integral type. The C++ standard does +// not guarantee this will work. +// +// There are implicit assumptions about sizeof(float) and +// sizeof(double). These assumptions are quite extant everywhere. 
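For concreteness, here is a minimal sketch of the float round trip that TestIntegral below exercises. It is illustrative only, assumes IEEE-754 single-precision floats, and relies on the headers already included at the top of this file:

static void BitCastFloatRoundTripSketch() {
  const float f = 1.0f;
  const uint32 u = bit_cast<uint32>(f);  // 0x3f800000 under IEEE-754
  const float g = bit_cast<float>(u);    // exact round trip back to 1.0f
  CHECK_EQ(0, memcmp(&f, &g, sizeof(f)));
}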
+ +template +void TestIntegral(const T values[], int num_values) { + for (int i = 0; i < num_values; ++i) { + T t0 = values[i]; + I i0 = bit_cast(t0); + T t1 = bit_cast(i0); + I i1 = bit_cast(t1); + ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T))); + ASSERT_EQ(i0, i1); + } +} + +TEST(BitCast, Bool) { + LOG(INFO) << "Test bool"; + static const bool bool_list[] = {false, true}; + TestMarshall(bool_list, TF_ARRAYSIZE(bool_list)); +} + +TEST(BitCast, Int32) { + static const int32 int_list[] = {0, 1, 100, 2147483647, + -1, -100, -2147483647, -2147483647 - 1}; + TestMarshall(int_list, TF_ARRAYSIZE(int_list)); +} + +TEST(BitCast, Int64) { + static const int64 int64_list[] = {0, 1, 1LL << 40, -1, -(1LL << 40)}; + TestMarshall(int64_list, TF_ARRAYSIZE(int64_list)); +} + +TEST(BitCast, Uint64) { + static const uint64 uint64_list[] = {0, 1, 1LLU << 40, 1LLU << 63}; + TestMarshall(uint64_list, TF_ARRAYSIZE(uint64_list)); +} + +TEST(BitCast, Float) { + static const float float_list[] = {0.0, 1.0, -1.0, 10.0, -10.0, 1e10, + 1e20, 1e-10, 1e-20, 2.71828, 3.14159}; + TestMarshall(float_list, TF_ARRAYSIZE(float_list)); + TestIntegral(float_list, TF_ARRAYSIZE(float_list)); + TestIntegral(float_list, TF_ARRAYSIZE(float_list)); +} + +TEST(BitCast, Double) { + static const double double_list[] = { + 0.0, + 1.0, + -1.0, + 10.0, + -10.0, + 1e10, + 1e100, + 1e-10, + 1e-100, + 2.718281828459045, + 3.141592653589793238462643383279502884197169399375105820974944}; + TestMarshall(double_list, TF_ARRAYSIZE(double_list)); + TestIntegral(double_list, TF_ARRAYSIZE(double_list)); + TestIntegral(double_list, TF_ARRAYSIZE(double_list)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/bits.h b/tensorflow/core/lib/core/bits.h new file mode 100644 index 0000000000..5456a63168 --- /dev/null +++ b/tensorflow/core/lib/core/bits.h @@ -0,0 +1,84 @@ +#ifndef TENSORFLOW_LIB_CORE_BITS_H_ +#define TENSORFLOW_LIB_CORE_BITS_H_ + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +int Log2Floor(uint32 n); +int Log2Floor64(uint64 n); + +// Return ceiling(log2(n)) for positive integer n. Returns -1 iff n == 0. +int Log2Ceiling(uint32 n); +int Log2Ceiling64(uint64 n); + +// ------------------------------------------------------------------------ +// Implementation details follow +// ------------------------------------------------------------------------ + +#if defined(__GNUC__) + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor(uint32 n) { + return n == 0 ? -1 : 31 ^ __builtin_clz(n); +} + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor64(uint64 n) { + return n == 0 ? -1 : 63 ^ __builtin_clzll(n); +} + +#else + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +inline int Log2Floor(uint32 n) { + if (n == 0) + return -1; + int log = 0; + uint32 value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32 x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + assert(value == 1); + return log; +} + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. 
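// For example: Log2Floor(1) == 0, Log2Floor(5) == 2, Log2Floor(0) == -1,
// and Log2Ceiling(5) == 3; the __builtin_clz and portable versions agree
// on all inputs.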
+// Log2Floor64() is defined in terms of Log2Floor32() +inline int Log2Floor64(uint64 n) { + const uint32 topbits = static_cast(n >> 32); + if (topbits == 0) { + // Top bits are zero, so scan in bottom bits + return Log2Floor(static_cast(n)); + } else { + return 32 + Log2Floor(topbits); + } +} + +#endif + +inline int Log2Ceiling(uint32 n) { + int floor = Log2Floor(n); + if (n == (n & ~(n - 1))) // zero or a power of two + return floor; + else + return floor + 1; +} + +inline int Log2Ceiling64(uint64 n) { + int floor = Log2Floor64(n); + if (n == (n & ~(n - 1))) // zero or a power of two + return floor; + else + return floor + 1; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_BITS_H_ diff --git a/tensorflow/core/lib/core/blocking_counter.h b/tensorflow/core/lib/core/blocking_counter.h new file mode 100644 index 0000000000..f141be2c76 --- /dev/null +++ b/tensorflow/core/lib/core/blocking_counter.h @@ -0,0 +1,41 @@ +#ifndef TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ +#define TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class BlockingCounter { + public: + BlockingCounter(int initial_count) : count_(initial_count) { + CHECK_GE(count_, 0); + } + + ~BlockingCounter() {} + + inline void DecrementCount() { + mutex_lock l(mu_); + --count_; + CHECK(count_ >= 0); + if (count_ == 0) { + cond_var_.notify_all(); + } + } + + inline void Wait() { + mutex_lock l(mu_); + while (count_ > 0) { + cond_var_.wait(l); + } + } + + private: + int count_; + mutex mu_; + condition_variable cond_var_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ diff --git a/tensorflow/core/lib/core/blocking_counter_test.cc b/tensorflow/core/lib/core/blocking_counter_test.cc new file mode 100644 index 0000000000..feb0342086 --- /dev/null +++ b/tensorflow/core/lib/core/blocking_counter_test.cc @@ -0,0 +1,36 @@ +#include + +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { +namespace { + +TEST(BlockingCounterTest, TestZero) { + BlockingCounter bc(0); + bc.Wait(); +} + +TEST(BlockingCounterTest, TestSingleThread) { + BlockingCounter bc(2); + bc.DecrementCount(); + bc.DecrementCount(); + bc.Wait(); +} + +TEST(BlockingCounterTest, TestMultipleThread) { + int N = 3; + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", N); + + BlockingCounter bc(N); + for (int i = 0; i < N; ++i) { + thread_pool->Schedule([&bc] { bc.DecrementCount(); }); + } + + bc.Wait(); + delete thread_pool; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/casts.h b/tensorflow/core/lib/core/casts.h new file mode 100644 index 0000000000..5b72048ac5 --- /dev/null +++ b/tensorflow/core/lib/core/casts.h @@ -0,0 +1,85 @@ +// Various Google-specific casting templates. +// +// This code is compiled directly on many platforms, including client +// platforms like Windows, Mac, and embedded systems. Before making +// any changes here, make sure that you're not breaking any platforms. +// + +#ifndef TENSORFLOW_LIB_CORE_CASTS_H_ +#define TENSORFLOW_LIB_CORE_CASTS_H_ + +#include // for memcpy + +namespace tensorflow { + +// bit_cast is a template function that implements the +// equivalent of "*reinterpret_cast(&source)". We need this in +// very low-level functions like the protobuf library and fast math +// support. 
+//
+// float f = 3.14159265358979;
+// int i = bit_cast<int32>(f);
+// // i = 0x40490fdb
+//
+// The classical address-casting method is:
+//
+// // WRONG
+// float f = 3.14159265358979; // WRONG
+// int i = * reinterpret_cast<int*>(&f); // WRONG
+//
+// The address-casting method actually produces undefined behavior
+// according to ISO C++ specification section 3.10 -15-. Roughly, this
+// section says: if an object in memory has one type, and a program
+// accesses it with a different type, then the result is undefined
+// behavior for most values of "different type".
+//
+// This is true for any cast syntax, either *(int*)&f or
+// *reinterpret_cast<int*>(&f). And it is particularly true for
+// conversions between integral lvalues and floating-point lvalues.
+//
+// The purpose of 3.10 -15- is to allow optimizing compilers to assume
+// that expressions with different types refer to different memory. gcc
+// 4.0.1 has an optimizer that takes advantage of this. So a
+// non-conforming program quietly produces wildly incorrect output.
+//
+// The problem is not the use of reinterpret_cast. The problem is type
+// punning: holding an object in memory of one type and reading its bits
+// back using a different type.
+//
+// The C++ standard is more subtle and complex than this, but that
+// is the basic idea.
+//
+// Anyways ...
+//
+// bit_cast<> calls memcpy() which is blessed by the standard,
+// especially by the example in section 3.9. Also, of course,
+// bit_cast<> wraps up the nasty logic in one place.
+//
+// Fortunately memcpy() is very fast. In optimized mode, with a
+// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
+// code with the minimal amount of data movement. On a 32-bit system,
+// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
+// compiles to two loads and two stores.
+//
+// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
+//
+// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
+// is likely to surprise you.
+//
+// Props to Bill Gibbons for the compile time assertion technique and
+// Art Komninos and Igor Tandetnik for the msvc experiments.
+// +// -- mec 2005-10-17 + +template +inline Dest bit_cast(const Source& source) { + static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match"); + + Dest dest; + memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_CASTS_H_ diff --git a/tensorflow/core/lib/core/coding.cc b/tensorflow/core/lib/core/coding.cc new file mode 100644 index 0000000000..efff554742 --- /dev/null +++ b/tensorflow/core/lib/core/coding.cc @@ -0,0 +1,164 @@ +#include "tensorflow/core/lib/core/coding.h" + +namespace tensorflow { +namespace core { + +void EncodeFixed32(char* buf, uint32 value) { + if (port::kLittleEndian) { + memcpy(buf, &value, sizeof(value)); + } else { + buf[0] = value & 0xff; + buf[1] = (value >> 8) & 0xff; + buf[2] = (value >> 16) & 0xff; + buf[3] = (value >> 24) & 0xff; + } +} + +void EncodeFixed64(char* buf, uint64 value) { + if (port::kLittleEndian) { + memcpy(buf, &value, sizeof(value)); + } else { + buf[0] = value & 0xff; + buf[1] = (value >> 8) & 0xff; + buf[2] = (value >> 16) & 0xff; + buf[3] = (value >> 24) & 0xff; + buf[4] = (value >> 32) & 0xff; + buf[5] = (value >> 40) & 0xff; + buf[6] = (value >> 48) & 0xff; + buf[7] = (value >> 56) & 0xff; + } +} + +void PutFixed32(string* dst, uint32 value) { + char buf[sizeof(value)]; + EncodeFixed32(buf, value); + dst->append(buf, sizeof(buf)); +} + +void PutFixed64(string* dst, uint64 value) { + char buf[sizeof(value)]; + EncodeFixed64(buf, value); + dst->append(buf, sizeof(buf)); +} + +char* EncodeVarint32(char* dst, uint32 v) { + // Operate on characters as unsigneds + unsigned char* ptr = reinterpret_cast(dst); + static const int B = 128; + if (v < (1 << 7)) { + *(ptr++) = v; + } else if (v < (1 << 14)) { + *(ptr++) = v | B; + *(ptr++) = v >> 7; + } else if (v < (1 << 21)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = v >> 14; + } else if (v < (1 << 28)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = v >> 21; + } else { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = (v >> 21) | B; + *(ptr++) = v >> 28; + } + return reinterpret_cast(ptr); +} + +void PutVarint32(string* dst, uint32 v) { + char buf[5]; + char* ptr = EncodeVarint32(buf, v); + dst->append(buf, ptr - buf); +} + +char* EncodeVarint64(char* dst, uint64 v) { + static const int B = 128; + unsigned char* ptr = reinterpret_cast(dst); + while (v >= B) { + *(ptr++) = (v & (B - 1)) | B; + v >>= 7; + } + *(ptr++) = static_cast(v); + return reinterpret_cast(ptr); +} + +void PutVarint64(string* dst, uint64 v) { + char buf[10]; + char* ptr = EncodeVarint64(buf, v); + dst->append(buf, ptr - buf); +} + +int VarintLength(uint64_t v) { + int len = 1; + while (v >= 128) { + v >>= 7; + len++; + } + return len; +} + +const char* GetVarint32PtrFallback(const char* p, const char* limit, + uint32* value) { + uint32 result = 0; + for (uint32 shift = 0; shift <= 28 && p < limit; shift += 7) { + uint32 byte = *(reinterpret_cast(p)); + p++; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + *value = result; + return reinterpret_cast(p); + } + } + return NULL; +} + +bool GetVarint32(StringPiece* input, uint32* value) { + const char* p = input->data(); + const char* limit = p + input->size(); + const char* q = GetVarint32Ptr(p, limit, value); + if (q == NULL) { + return false; + } else { + *input = StringPiece(q, limit - q); + return true; + } +} + +const char* 
GetVarint64Ptr(const char* p, const char* limit, uint64* value) { + uint64 result = 0; + for (uint32 shift = 0; shift <= 63 && p < limit; shift += 7) { + uint64 byte = *(reinterpret_cast(p)); + p++; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + *value = result; + return reinterpret_cast(p); + } + } + return NULL; +} + +bool GetVarint64(StringPiece* input, uint64* value) { + const char* p = input->data(); + const char* limit = p + input->size(); + const char* q = GetVarint64Ptr(p, limit, value); + if (q == NULL) { + return false; + } else { + *input = StringPiece(q, limit - q); + return true; + } +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h new file mode 100644 index 0000000000..0c14bf1bbf --- /dev/null +++ b/tensorflow/core/lib/core/coding.h @@ -0,0 +1,55 @@ +// Endian-neutral encoding: +// * Fixed-length numbers are encoded with least-significant byte first +// * In addition we support variable length "varint" encoding +// * Strings are encoded prefixed by their length in varint format + +#ifndef TENSORFLOW_LIB_CORE_CODING_H_ +#define TENSORFLOW_LIB_CORE_CODING_H_ + +#include "tensorflow/core/lib/core/raw_coding.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// Lower-level versions of Put... that write directly into a character buffer +// REQUIRES: dst has enough space for the value being written +extern void EncodeFixed32(char* dst, uint32 value); +extern void EncodeFixed64(char* dst, uint64 value); +extern void PutFixed32(string* dst, uint32 value); +extern void PutFixed64(string* dst, uint64 value); + +extern void PutVarint32(string* dst, uint32 value); +extern void PutVarint64(string* dst, uint64 value); + +extern bool GetVarint32(StringPiece* input, uint32* value); +extern bool GetVarint64(StringPiece* input, uint64* value); + +extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v); +extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v); + +// Internal routine for use by fallback path of GetVarint32Ptr +extern const char* GetVarint32PtrFallback(const char* p, const char* limit, + uint32* value); +inline const char* GetVarint32Ptr(const char* p, const char* limit, + uint32* value) { + if (p < limit) { + uint32 result = *(reinterpret_cast(p)); + if ((result & 128) == 0) { + *value = result; + return p + 1; + } + } + return GetVarint32PtrFallback(p, limit, value); +} + +extern char* EncodeVarint64(char* dst, uint64 v); + +// Returns the length of the varint32 or varint64 encoding of "v" +extern int VarintLength(uint64_t v); + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_CODING_H_ diff --git a/tensorflow/core/lib/core/coding_test.cc b/tensorflow/core/lib/core/coding_test.cc new file mode 100644 index 0000000000..5e9e2c5e96 --- /dev/null +++ b/tensorflow/core/lib/core/coding_test.cc @@ -0,0 +1,168 @@ +#include "tensorflow/core/lib/core/coding.h" + +#include + +namespace tensorflow { +namespace core { + +TEST(Coding, Fixed32) { + static const int N = 100000; + + string s; + for (uint32 v = 0; v < N; v++) { + char buf[sizeof(uint32)]; + EncodeFixed32(buf, v); + s.append(buf, sizeof(buf)); + } + + const char* p = s.data(); + for (uint32 v = 0; v < N; v++) { + uint32 actual = DecodeFixed32(p); + ASSERT_EQ(v, actual); + p += sizeof(uint32); + } 
+} + +TEST(Coding, Fixed64) { + string s; + for (int power = 0; power <= 63; power++) { + uint64 v = static_cast(1) << power; + char buf[sizeof(uint64)]; + EncodeFixed64(buf, v - 1); + s.append(buf, sizeof(buf)); + EncodeFixed64(buf, v + 0); + s.append(buf, sizeof(buf)); + EncodeFixed64(buf, v + 1); + s.append(buf, sizeof(buf)); + } + + const char* p = s.data(); + for (int power = 0; power <= 63; power++) { + uint64 v = static_cast(1) << power; + uint64 actual; + actual = DecodeFixed64(p); + ASSERT_EQ(v - 1, actual); + p += sizeof(uint64); + + actual = DecodeFixed64(p); + ASSERT_EQ(v + 0, actual); + p += sizeof(uint64); + + actual = DecodeFixed64(p); + ASSERT_EQ(v + 1, actual); + p += sizeof(uint64); + } +} + +// Test that encoding routines generate little-endian encodings +TEST(Coding, EncodingOutput) { + char dst[8]; + EncodeFixed32(dst, 0x04030201); + ASSERT_EQ(0x01, static_cast(dst[0])); + ASSERT_EQ(0x02, static_cast(dst[1])); + ASSERT_EQ(0x03, static_cast(dst[2])); + ASSERT_EQ(0x04, static_cast(dst[3])); + + EncodeFixed64(dst, 0x0807060504030201ull); + ASSERT_EQ(0x01, static_cast(dst[0])); + ASSERT_EQ(0x02, static_cast(dst[1])); + ASSERT_EQ(0x03, static_cast(dst[2])); + ASSERT_EQ(0x04, static_cast(dst[3])); + ASSERT_EQ(0x05, static_cast(dst[4])); + ASSERT_EQ(0x06, static_cast(dst[5])); + ASSERT_EQ(0x07, static_cast(dst[6])); + ASSERT_EQ(0x08, static_cast(dst[7])); +} + +TEST(Coding, Varint32) { + string s; + for (uint32 i = 0; i < (32 * 32); i++) { + uint32 v = (i / 32) << (i % 32); + PutVarint32(&s, v); + } + + const char* p = s.data(); + const char* limit = p + s.size(); + for (uint32 i = 0; i < (32 * 32); i++) { + uint32 expected = (i / 32) << (i % 32); + uint32 actual; + p = GetVarint32Ptr(p, limit, &actual); + ASSERT_TRUE(p != NULL); + ASSERT_EQ(expected, actual); + } + ASSERT_EQ(p, s.data() + s.size()); +} + +TEST(Coding, Varint64) { + // Construct the list of values to check + std::vector values; + // Some special values + values.push_back(0); + values.push_back(100); + values.push_back(~static_cast(0)); + values.push_back(~static_cast(0) - 1); + for (uint32 k = 0; k < 64; k++) { + // Test values near powers of two + const uint64 power = 1ull << k; + values.push_back(power); + values.push_back(power - 1); + values.push_back(power + 1); + } + + string s; + for (size_t i = 0; i < values.size(); i++) { + PutVarint64(&s, values[i]); + } + + const char* p = s.data(); + const char* limit = p + s.size(); + for (size_t i = 0; i < values.size(); i++) { + ASSERT_TRUE(p < limit); + uint64 actual; + p = GetVarint64Ptr(p, limit, &actual); + ASSERT_TRUE(p != NULL); + ASSERT_EQ(values[i], actual); + } + ASSERT_EQ(p, limit); +} + +TEST(Coding, Varint32Overflow) { + uint32 result; + string input("\x81\x82\x83\x84\x85\x11"); + ASSERT_TRUE(GetVarint32Ptr(input.data(), input.data() + input.size(), + &result) == NULL); +} + +TEST(Coding, Varint32Truncation) { + uint32 large_value = (1u << 31) + 100; + string s; + PutVarint32(&s, large_value); + uint32 result; + for (size_t len = 0; len < s.size() - 1; len++) { + ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + len, &result) == NULL); + } + ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + s.size(), &result) != NULL); + ASSERT_EQ(large_value, result); +} + +TEST(Coding, Varint64Overflow) { + uint64 result; + string input("\x81\x82\x83\x84\x85\x81\x82\x83\x84\x85\x11"); + ASSERT_TRUE(GetVarint64Ptr(input.data(), input.data() + input.size(), + &result) == NULL); +} + +TEST(Coding, Varint64Truncation) { + uint64 large_value = (1ull << 63) + 100ull; + 
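  // (For reference, the wire format under test stores 7 payload bits per
  // byte, low-order group first, with the high bit set on every byte except
  // the last. A small worked example using the encoder above:
  //
  //   string tmp;
  //   PutVarint32(&tmp, 300);     // 300 = 0x12c
  //   // tmp now holds "\xac\x02" (0x2c | 0x80, then 0x02)
  //
  // GetVarint32() reverses this and consumes exactly those two bytes.)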
string s; + PutVarint64(&s, large_value); + uint64 result; + for (size_t len = 0; len < s.size() - 1; len++) { + ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + len, &result) == NULL); + } + ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + s.size(), &result) != NULL); + ASSERT_EQ(large_value, result); +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc new file mode 100644 index 0000000000..0f1072ffaa --- /dev/null +++ b/tensorflow/core/lib/core/command_line_flags.cc @@ -0,0 +1,94 @@ +#include "tensorflow/core/lib/core/command_line_flags.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { +namespace { + +// Templated function to convert a string to target values. +// Return true if the conversion is successful. Otherwise, return false. +template +bool StringToValue(const string& content, T* value); + +template <> +bool StringToValue(const string& content, int* value) { + return str_util::NumericParse32(content, value); +} + +// Parse a single argument by linearly searching through the command table. +// The input format is: --argument=value. +// Return OK if the argument is used. It store the extracted value into the +// matching flag. +// Return NOT_FOUND if the argument is not recognized. +// Retrun INVALID_ARGUMENT if the command is recognized, but fails to extract +// its value. +template +Status ParseArgument(const string& argument) { + for (auto& command : + internal::CommandLineFlagRegistry::Instance()->commands) { + string prefix = strings::StrCat("--", command.name, "="); + if (tensorflow::StringPiece(argument).starts_with(prefix)) { + string content = argument.substr(prefix.length()); + if (StringToValue(content, command.value)) { + return Status::OK(); + } + return Status(error::INVALID_ARGUMENT, + strings::StrCat("Cannot parse integer in: ", argument)); + } + } + return Status(error::NOT_FOUND, + strings::StrCat("Unknown command: ", argument)); +} + +// A specialization for booleans. The input format is: +// "--argument" or "--noargument". +// Parse a single argument by linearly searching through the command table. +// Return OK if the argument is used. The value is stored in the matching flag. +// Return NOT_FOUND if the argument is not recognized. +template <> +Status ParseArgument(const string& argument) { + for (auto& command : + internal::CommandLineFlagRegistry::Instance()->commands) { + if (argument == strings::StrCat("--", command.name)) { + *command.value = true; + return Status::OK(); + } else if (argument == strings::StrCat("--no", command.name)) { + *command.value = false; + return Status::OK(); + } + } + return Status(error::NOT_FOUND, + strings::StrCat("Unknown command: ", argument)); +} +} // namespace + +Status ParseCommandLineFlags(int* argc, char* argv[]) { + int unused_argc = 1; + for (int index = 1; index < *argc; ++index) { + Status s; + // Search bool commands. + s = ParseArgument(argv[index]); + if (s.ok()) { + continue; + } + if (s.code() != error::NOT_FOUND) { + return s; + } + // Search int32 commands. + s = ParseArgument(argv[index]); + if (s.ok()) { + continue; + } + if (s.code() != error::NOT_FOUND) { + return s; + } + // Pointer swap the unused argument to the front. 
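  // (End to end, the registration macros in command_line_flags.h are used
  // roughly as follows; the flag names here are hypothetical:
  //
  //   TF_DEFINE_int32(batch_size, 32, "examples per step");
  //   TF_DEFINE_bool(use_gpu, true, "place ops on GPU when possible");
  //
  //   int main(int argc, char* argv[]) {
  //     Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
  //     // "--batch_size=64" sets FLAGS_batch_size; "--nouse_gpu" clears
  //     // FLAGS_use_gpu; unrecognized arguments are left in argv and
  //     // *argc is updated to count them.
  //   }
  // )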
+ std::swap(argv[unused_argc++], argv[index]); + } + *argc = unused_argc; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/command_line_flags.h b/tensorflow/core/lib/core/command_line_flags.h new file mode 100644 index 0000000000..f1a94c11f9 --- /dev/null +++ b/tensorflow/core/lib/core/command_line_flags.h @@ -0,0 +1,60 @@ +#ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace internal { + +template +struct CommandLineFlagRegistry { + static CommandLineFlagRegistry* Instance() { + static CommandLineFlagRegistry instance_; + return &instance_; + } + struct Command { + string name; + T* value; + string text; + }; + std::vector commands; + + private: + CommandLineFlagRegistry() {} + TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry); +}; + +template +struct CommandLineFlagRegister { + CommandLineFlagRegister(const string& name, T* val, const string& text) { + CommandLineFlagRegistry::Instance()->commands.push_back( + {name, val, text}); + } +}; + +#define TF_DEFINE_variable(type, name, default_value, text) \ + type FLAGS_##name = default_value; \ + namespace TF_flags_internal { \ + tensorflow::internal::CommandLineFlagRegister \ + TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \ + } // namespace TF_flags_internal + +} // namespace internal + +#define TF_DEFINE_int32(name, default_value, text) \ + TF_DEFINE_variable(int32, name, default_value, text); + +#define TF_DEFINE_bool(name, default_value, text) \ + TF_DEFINE_variable(bool, name, default_value, text); + +// Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv. +// Returned the number of unused arguments in *argc. +// Return error Status if the parsing encounters errors. +// TODO(opensource): switch to a command line argument parser that can be +// shared with other tests. +Status ParseCommandLineFlags(int* argc, char* argv[]); + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/core/lib/core/error_codes.proto b/tensorflow/core/lib/core/error_codes.proto new file mode 100644 index 0000000000..6735fd8f88 --- /dev/null +++ b/tensorflow/core/lib/core/error_codes.proto @@ -0,0 +1,145 @@ +syntax = "proto3"; + +package tensorflow.error; +// option cc_enable_arenas = true; + +// The canonical error codes for TensorFlow APIs. +// +// Warnings: +// +// - Do not change any numeric assignments. +// - Changes to this list should only be made if there is a compelling +// need that can't be satisfied in another way. Such changes +// must be approved by at least two OWNERS. +// +// Sometimes multiple error codes may apply. Services should return +// the most specific error code that applies. For example, prefer +// OUT_OF_RANGE over FAILED_PRECONDITION if both codes apply. +// Similarly prefer NOT_FOUND or ALREADY_EXISTS over FAILED_PRECONDITION. +enum Code { + // Not an error; returned on success + OK = 0; + + // The operation was cancelled (typically by the caller). + CANCELLED = 1; + + // Unknown error. An example of where this error may be returned is + // if a Status value received from another address space belongs to + // an error-space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. 
+ UNKNOWN = 2; + + // Client specified an invalid argument. Note that this differs + // from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + INVALID_ARGUMENT = 3; + + // Deadline expired before operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + DEADLINE_EXCEEDED = 4; + + // Some requested entity (e.g., file or directory) was not found. + // For privacy reasons, this code *may* be returned when the client + // does not have the access right to the entity. + NOT_FOUND = 5; + + // Some entity that we attempted to create (e.g., file or directory) + // already exists. + ALREADY_EXISTS = 6; + + // The caller does not have permission to execute the specified + // operation. PERMISSION_DENIED must not be used for rejections + // caused by exhausting some resource (use RESOURCE_EXHAUSTED + // instead for those errors). PERMISSION_DENIED must not be + // used if the caller can not be identified (use UNAUTHENTICATED + // instead for those errors). + PERMISSION_DENIED = 7; + + // The request does not have valid authentication credentials for the + // operation. + UNAUTHENTICATED = 16; + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + RESOURCE_EXHAUSTED = 8; + + // Operation was rejected because the system is not in a state + // required for the operation's execution. For example, directory + // to be deleted may be non-empty, an rmdir operation is applied to + // a non-directory, etc. + // + // A litmus test that may help a service implementor in deciding + // between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE: + // (a) Use UNAVAILABLE if the client can retry just the failing call. + // (b) Use ABORTED if the client should retry at a higher-level + // (e.g., restarting a read-modify-write sequence). + // (c) Use FAILED_PRECONDITION if the client should not retry until + // the system state has been explicitly fixed. E.g., if an "rmdir" + // fails because the directory is non-empty, FAILED_PRECONDITION + // should be returned since the client should not retry unless + // they have first fixed up the directory by deleting files from it. + // (d) Use FAILED_PRECONDITION if the client performs conditional + // REST Get/Update/Delete on a resource and the resource on the + // server does not match the condition. E.g., conflicting + // read-modify-write on the same resource. + FAILED_PRECONDITION = 9; + + // The operation was aborted, typically due to a concurrency issue + // like sequencer check failures, transaction aborts, etc. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + ABORTED = 10; + + // Operation was attempted past the valid range. E.g., seeking or + // reading past end of file. + // + // Unlike INVALID_ARGUMENT, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate INVALID_ARGUMENT if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // OUT_OF_RANGE if asked to read from an offset past the current + // file size. 
+ // + // There is a fair bit of overlap between FAILED_PRECONDITION and + // OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an OUT_OF_RANGE error to detect when + // they are done. + OUT_OF_RANGE = 11; + + // Operation is not implemented or not supported/enabled in this service. + UNIMPLEMENTED = 12; + + // Internal errors. Means some invariants expected by underlying + // system has been broken. If you see one of these errors, + // something is very broken. + INTERNAL = 13; + + // The service is currently unavailable. This is a most likely a + // transient condition and may be corrected by retrying with + // a backoff. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + UNAVAILABLE = 14; + + // Unrecoverable data loss or corruption. + DATA_LOSS = 15; + + // An extra enum entry to prevent people from writing code that + // fails to compile when a new code is added. + // + // Nobody should ever reference this enumeration entry. In particular, + // if you write C++ code that switches on this enumeration, add a default: + // case instead of a case that mentions this enumeration entry. + // + // Nobody should rely on the value (currently 20) listed here. It + // may change in the future. + DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20; +} diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h new file mode 100644 index 0000000000..b0badd8c4d --- /dev/null +++ b/tensorflow/core/lib/core/errors.h @@ -0,0 +1,131 @@ +#ifndef TENSORFLOW_LIB_CORE_ERRORS_H_ +#define TENSORFLOW_LIB_CORE_ERRORS_H_ + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace errors { + +typedef ::tensorflow::error::Code Code; + +// Append some context to an error message. Each time we append +// context put it on a new line, since it is possible for there +// to be several layers of additional context. +template +void AppendToMessage(::tensorflow::Status* status, Args... args) { + *status = ::tensorflow::Status( + status->code(), + strings::StrCat(status->error_message(), "\n\t", args...)); +} + +// For propagating errors when calling a function. +#define TF_RETURN_IF_ERROR(expr) \ + do { \ + const ::tensorflow::Status _status = (expr); \ + if (TF_PREDICT_FALSE(!_status.ok())) return _status; \ + } while (0) + +#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \ + do { \ + ::tensorflow::Status _status = (expr); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + ::tensorflow::errors::AppendToMessage(&_status, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// Convenience functions for generating and using error status. +// Example usage: +// status.Update(errors::InvalidArgument("The ", foo, " isn't right.")); +// if (errors::IsInvalidArgument(status)) { ... } +// switch (status.code()) { case error::INVALID_ARGUMENT: ... } + +#define DECLARE_ERROR(FUNC, CONST) \ + template \ + inline ::tensorflow::Status FUNC(Args... 
args) { \ + return ::tensorflow::Status(::tensorflow::error::CONST, \ + strings::StrCat(args...)); \ + } \ + inline bool Is##FUNC(const ::tensorflow::Status& status) { \ + return status.code() == ::tensorflow::error::CONST; \ + } + +DECLARE_ERROR(Cancelled, CANCELLED) +DECLARE_ERROR(InvalidArgument, INVALID_ARGUMENT) +DECLARE_ERROR(NotFound, NOT_FOUND) +DECLARE_ERROR(AlreadyExists, ALREADY_EXISTS) +DECLARE_ERROR(ResourceExhausted, RESOURCE_EXHAUSTED) +DECLARE_ERROR(Unavailable, UNAVAILABLE) +DECLARE_ERROR(FailedPrecondition, FAILED_PRECONDITION) +DECLARE_ERROR(OutOfRange, OUT_OF_RANGE) +DECLARE_ERROR(Unimplemented, UNIMPLEMENTED) +DECLARE_ERROR(Internal, INTERNAL) +DECLARE_ERROR(Aborted, ABORTED) +DECLARE_ERROR(DeadlineExceeded, DEADLINE_EXCEEDED) +DECLARE_ERROR(DataLoss, DATA_LOSS) +DECLARE_ERROR(Unknown, UNKNOWN) +DECLARE_ERROR(PermissionDenied, PERMISSION_DENIED) +DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED) + +#undef DECLARE_ERROR + +// The CanonicalCode() for non-errors. +using ::tensorflow::error::OK; + +// Convenience macros for asserting and handling exceptional conditions. +// Analogous to the CHECK* macros provided by logging.h. +// +// Example use: +// void Compute(OperationContext* context) { +// OP_REQUIRES(context, context->num_inputs() == 2, +// errors::InvalidArgument("FooOp requires 2 arguments")); +// ... +// Status status = SomeUncertainMethod(); +// OP_REQUIRES_OK(context, status); +// ... +// } + +#define OP_REQUIRES(CTX, EXP, STATUS) \ + if (!(EXP)) { \ + ::tensorflow::Status _s(STATUS); \ + VLOG(1) << _s; \ + (CTX)->SetStatus(_s); \ + return; \ + } + +#define OP_REQUIRES_OK(CTX, STATUS) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!_s.ok()) { \ + LOG(WARNING) << _s; \ + (CTX)->SetStatus(_s); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ + if (!(EXP)) { \ + ::tensorflow::Status _s(STATUS); \ + VLOG(1) << _s; \ + (CTX)->SetStatus(_s); \ + (CALLBACK)(); \ + return; \ + } + +#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!_s.ok()) { \ + LOG(WARNING) << _s; \ + (CTX)->SetStatus(_s); \ + (CALLBACK)(); \ + return; \ + } \ + } while (0) + +} // namespace errors +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_ERRORS_H_ diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h new file mode 100644 index 0000000000..071e24285a --- /dev/null +++ b/tensorflow/core/lib/core/notification.h @@ -0,0 +1,42 @@ +#ifndef TENSORFLOW_UTIL_NOTIFICATION_H_ +#define TENSORFLOW_UTIL_NOTIFICATION_H_ + +#include + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class Notification { + public: + Notification() : notified_(false) {} + ~Notification() {} + + void Notify() { + mutex_lock l(mu_); + assert(!notified_); + notified_ = true; + cv_.notify_all(); + } + + bool HasBeenNotified() { + mutex_lock l(mu_); + return notified_; + } + + void WaitForNotification() { + mutex_lock l(mu_); + while (!notified_) { + cv_.wait(l); + } + } + + private: + mutex mu_; + condition_variable cv_; + bool notified_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_NOTIFICATION_H_ diff --git a/tensorflow/core/lib/core/notification_test.cc b/tensorflow/core/lib/core/notification_test.cc new file mode 100644 index 0000000000..a9e8942f05 --- /dev/null +++ b/tensorflow/core/lib/core/notification_test.cc @@ -0,0 +1,64 @@ +#include + +#include "tensorflow/core/lib/core/notification.h" +#include 
"tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace { + +TEST(NotificationTest, TestSingleNotification) { + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", 1); + + int counter = 0; + Notification start; + Notification proceed; + thread_pool->Schedule([&start, &proceed, &counter] { + start.Notify(); + proceed.WaitForNotification(); + ++counter; + }); + + // Wait for the thread to start + start.WaitForNotification(); + + // The thread should be waiting for the 'proceed' notification. + EXPECT_EQ(0, counter); + + // Unblock the thread + proceed.Notify(); + + delete thread_pool; // Wait for closure to finish. + + // Verify the counter has been incremented + EXPECT_EQ(1, counter); +} + +TEST(NotificationTest, TestMultipleThreadsWaitingOnNotification) { + const int num_closures = 4; + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "test", num_closures); + + mutex lock; + int counter = 0; + Notification n; + + for (int i = 0; i < num_closures; ++i) { + thread_pool->Schedule([&n, &lock, &counter] { + n.WaitForNotification(); + mutex_lock l(lock); + ++counter; + }); + } + sleep(1); + + EXPECT_EQ(0, counter); + + n.Notify(); + delete thread_pool; // Wait for all closures to finish. + EXPECT_EQ(4, counter); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/raw_coding.h b/tensorflow/core/lib/core/raw_coding.h new file mode 100644 index 0000000000..1fe49b75bb --- /dev/null +++ b/tensorflow/core/lib/core/raw_coding.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_LIB_CORE_RAW_CODING_H_ +#define TENSORFLOW_LIB_CORE_RAW_CODING_H_ + +#include +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace core { + +// Lower-level versions of Get... that read directly from a character buffer +// without any bounds checking. + +inline uint32 DecodeFixed32(const char* ptr) { + if (port::kLittleEndian) { + // Load the raw bytes + uint32 result; + memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load + return result; + } else { + return ((static_cast(static_cast(ptr[0]))) | + (static_cast(static_cast(ptr[1])) << 8) | + (static_cast(static_cast(ptr[2])) << 16) | + (static_cast(static_cast(ptr[3])) << 24)); + } +} + +inline uint64 DecodeFixed64(const char* ptr) { + if (port::kLittleEndian) { + // Load the raw bytes + uint64 result; + memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load + return result; + } else { + uint64 lo = DecodeFixed32(ptr); + uint64 hi = DecodeFixed32(ptr + 4); + return (hi << 32) | lo; + } +} + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_RAW_CODING_H_ diff --git a/tensorflow/core/lib/core/refcount.cc b/tensorflow/core/lib/core/refcount.cc new file mode 100644 index 0000000000..3ed8c58eb8 --- /dev/null +++ b/tensorflow/core/lib/core/refcount.cc @@ -0,0 +1,35 @@ +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace core { + +RefCounted::RefCounted() : ref_(1) {} + +RefCounted::~RefCounted() { DCHECK_EQ(ref_.load(), 0); } + +void RefCounted::Ref() const { + DCHECK_GE(ref_.load(), 1); + ref_.fetch_add(1, std::memory_order_relaxed); +} + +bool RefCounted::Unref() const { + DCHECK_GT(ref_.load(), 0); + // If ref_==1, this object is owned only by the caller. Bypass a locked op + // in that case. 
+ if (ref_.load(std::memory_order_acquire) == 1 || ref_.fetch_sub(1) == 1) { + // Make DCHECK in ~RefCounted happy + DCHECK((ref_.store(0), true)); + delete this; + return true; + } else { + return false; + } +} + +bool RefCounted::RefCountIsOne() const { + return (ref_.load(std::memory_order_acquire) == 1); +} + +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/refcount.h b/tensorflow/core/lib/core/refcount.h new file mode 100644 index 0000000000..f727750f9e --- /dev/null +++ b/tensorflow/core/lib/core/refcount.h @@ -0,0 +1,63 @@ +#ifndef TENSORFLOW_LIB_CORE_REFCOUNT_H_ +#define TENSORFLOW_LIB_CORE_REFCOUNT_H_ + +#include + +namespace tensorflow { +namespace core { + +class RefCounted { + public: + // Initial reference count is one. + RefCounted(); + + // Increments reference count by one. + void Ref() const; + + // Decrements reference count by one. If the count remains + // positive, returns false. When the count reaches zero, returns + // true and deletes this, in which case the caller must not access + // the object afterward. + bool Unref() const; + + // Return whether the reference count is one. + // If the reference count is used in the conventional way, a + // reference count of 1 implies that the current thread owns the + // reference and no other thread shares it. + // This call performs the test for a reference count of one, and + // performs the memory barrier needed for the owning thread + // to act on the object, knowing that it has exclusive access to the + // object. + bool RefCountIsOne() const; + + protected: + // Make destructor protected so that RefCounted objects cannot + // be instantiated directly. Only subclasses can be instantiated. + virtual ~RefCounted(); + + private: + mutable std::atomic_int_fast32_t ref_; + + RefCounted(const RefCounted&) = delete; + void operator=(const RefCounted&) = delete; +}; + +// Helper class to unref an object when out-of-scope. 
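// Example (illustrative; MyType is hypothetical):
//
//   class MyType : public RefCounted { /* ... */ };
//
//   MyType* obj = new MyType;      // refcount starts at 1
//   {
//     ScopedUnref cleanup(obj);    // guarantees one Unref() at scope exit
//     obj->Ref();                  // hand out an extra reference...
//     obj->Unref();                // ...which its holder releases
//   }                              // cleanup runs: count hits 0, obj deleted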
+class ScopedUnref { + public: + explicit ScopedUnref(RefCounted* o) : obj_(o) {} + ~ScopedUnref() { + if (obj_) obj_->Unref(); + } + + private: + RefCounted* obj_; + + ScopedUnref(const ScopedUnref&) = delete; + void operator=(const ScopedUnref&) = delete; +}; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_REFCOUNT_H_ diff --git a/tensorflow/core/lib/core/refcount_test.cc b/tensorflow/core/lib/core/refcount_test.cc new file mode 100644 index 0000000000..c042be2d61 --- /dev/null +++ b/tensorflow/core/lib/core/refcount_test.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/core/refcount.h" + +#include + +namespace tensorflow { +namespace core { +namespace { + +static int constructed = 0; +static int destroyed = 0; + +class MyRef : public RefCounted { + public: + MyRef() { constructed++; } + ~MyRef() override { destroyed++; } +}; + +class RefTest : public testing::Test { + public: + RefTest() { + constructed = 0; + destroyed = 0; + } +}; + +TEST_F(RefTest, New) { + MyRef* ref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(1, constructed); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, RefUnref) { + MyRef* ref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + ref->Ref(); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(0, destroyed); + ref->Unref(); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, RefCountOne) { + MyRef* ref = new MyRef; + ASSERT_TRUE(ref->RefCountIsOne()); + ref->Unref(); +} + +TEST_F(RefTest, RefCountNotOne) { + MyRef* ref = new MyRef; + ref->Ref(); + ASSERT_FALSE(ref->RefCountIsOne()); + ref->Unref(); + ref->Unref(); +} + +TEST_F(RefTest, ConstRefUnref) { + const MyRef* cref = new MyRef; + ASSERT_EQ(1, constructed); + ASSERT_EQ(0, destroyed); + cref->Ref(); + ASSERT_EQ(0, destroyed); + cref->Unref(); + ASSERT_EQ(0, destroyed); + cref->Unref(); + ASSERT_EQ(1, destroyed); +} + +TEST_F(RefTest, ReturnOfUnref) { + MyRef* ref = new MyRef; + ref->Ref(); + EXPECT_FALSE(ref->Unref()); + EXPECT_TRUE(ref->Unref()); +} + +TEST_F(RefTest, ScopedUnref) { + { ScopedUnref unref(new MyRef); } + EXPECT_EQ(destroyed, 1); +} + +TEST_F(RefTest, ScopedUnref_Nullptr) { + { ScopedUnref unref(nullptr); } + EXPECT_EQ(destroyed, 0); +} + +} // namespace +} // namespace core +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/status.cc b/tensorflow/core/lib/core/status.cc new file mode 100644 index 0000000000..24ce842560 --- /dev/null +++ b/tensorflow/core/lib/core/status.cc @@ -0,0 +1,107 @@ +#include "tensorflow/core/public/status.h" +#include + +namespace tensorflow { + +Status::Status(tensorflow::error::Code code, StringPiece msg) { + assert(code != tensorflow::error::OK); + state_ = new State; + state_->code = code; + state_->msg = msg.ToString(); +} +Status::~Status() { delete state_; } + +void Status::Update(const Status& new_status) { + if (ok()) { + *this = new_status; + } +} + +void Status::SlowCopyFrom(const State* src) { + delete state_; + if (src == nullptr) { + state_ = nullptr; + } else { + state_ = new State(*src); + } +} + +const string& Status::empty_string() { + static string* empty = new string; + return *empty; +} + +string Status::ToString() const { + if (state_ == NULL) { + return "OK"; + } else { + char tmp[30]; + const char* type; + switch (code()) { + case tensorflow::error::CANCELLED: + type = "Cancelled"; + break; + case tensorflow::error::UNKNOWN: + type = "Unknown"; + break; + case tensorflow::error::INVALID_ARGUMENT: + type = "Invalid 
argument"; + break; + case tensorflow::error::DEADLINE_EXCEEDED: + type = "Deadline exceeded"; + break; + case tensorflow::error::NOT_FOUND: + type = "Not found"; + break; + case tensorflow::error::ALREADY_EXISTS: + type = "Already exists"; + break; + case tensorflow::error::PERMISSION_DENIED: + type = "Permission denied"; + break; + case tensorflow::error::UNAUTHENTICATED: + type = "Unauthenticated"; + break; + case tensorflow::error::RESOURCE_EXHAUSTED: + type = "Resource exhausted"; + break; + case tensorflow::error::FAILED_PRECONDITION: + type = "Failed precondition"; + break; + case tensorflow::error::ABORTED: + type = "Aborted"; + break; + case tensorflow::error::OUT_OF_RANGE: + type = "Out of range"; + break; + case tensorflow::error::UNIMPLEMENTED: + type = "Unimplemented"; + break; + case tensorflow::error::INTERNAL: + type = "Internal"; + break; + case tensorflow::error::UNAVAILABLE: + type = "Unavailable"; + break; + case tensorflow::error::DATA_LOSS: + type = "Data loss"; + break; + default: + snprintf(tmp, sizeof(tmp), "Unknown code(%d)", + static_cast(code())); + type = tmp; + break; + } + string result(type); + result += ": "; + result += state_->msg; + return result; + } +} + +std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc new file mode 100644 index 0000000000..3ef6b3302a --- /dev/null +++ b/tensorflow/core/lib/core/status_test.cc @@ -0,0 +1,84 @@ +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include + +namespace tensorflow { + +TEST(Status, OK) { + EXPECT_EQ(Status::OK().code(), error::OK); + EXPECT_EQ(Status::OK().error_message(), ""); + EXPECT_OK(Status::OK()); + ASSERT_OK(Status::OK()); + EXPECT_EQ(Status::OK(), Status()); + Status s; + EXPECT_TRUE(s.ok()); +} + +TEST(DeathStatus, CheckOK) { + Status status(errors::InvalidArgument("Invalid")); + ASSERT_DEATH(TF_CHECK_OK(status), "Invalid"); +} + +TEST(Status, Set) { + Status status; + status = Status(error::CANCELLED, "Error message"); + EXPECT_EQ(status.code(), error::CANCELLED); + EXPECT_EQ(status.error_message(), "Error message"); +} + +TEST(Status, Copy) { + Status a(errors::InvalidArgument("Invalid")); + Status b(a); + ASSERT_EQ(a.ToString(), b.ToString()); +} + +TEST(Status, Assign) { + Status a(errors::InvalidArgument("Invalid")); + Status b; + b = a; + ASSERT_EQ(a.ToString(), b.ToString()); +} + +TEST(Status, Update) { + Status s; + s.Update(Status::OK()); + ASSERT_TRUE(s.ok()); + Status a(errors::InvalidArgument("Invalid")); + s.Update(a); + ASSERT_EQ(s.ToString(), a.ToString()); + Status b(errors::Internal("Internal")); + s.Update(b); + ASSERT_EQ(s.ToString(), a.ToString()); + s.Update(Status::OK()); + ASSERT_EQ(s.ToString(), a.ToString()); + ASSERT_FALSE(s.ok()); +} + +TEST(Status, EqualsOK) { ASSERT_EQ(Status::OK(), Status()); } + +TEST(Status, EqualsSame) { + Status a(errors::InvalidArgument("Invalid")); + Status b(errors::InvalidArgument("Invalid")); + ASSERT_EQ(a, b); +} + +TEST(Status, EqualsCopy) { + const Status a(errors::InvalidArgument("Invalid")); + const Status b = a; + ASSERT_EQ(a, b); +} + +TEST(Status, EqualsDifferentCode) { + const Status a(errors::InvalidArgument("message")); + const Status b(errors::Internal("message")); + ASSERT_NE(a, b); +} + +TEST(Status, EqualsDifferentMessage) { + const Status 
a(errors::InvalidArgument("message")); + const Status b(errors::InvalidArgument("another")); + ASSERT_NE(a, b); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h new file mode 100644 index 0000000000..b3b4db429f --- /dev/null +++ b/tensorflow/core/lib/core/status_test_util.h @@ -0,0 +1,20 @@ +#ifndef TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_ +#define TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_ + +#include +#include "tensorflow/core/public/status.h" + +// Macros for testing the results of functions that return util::Status. + +#define EXPECT_OK(statement) EXPECT_EQ(::tensorflow::Status::OK(), (statement)) +#define ASSERT_OK(statement) ASSERT_EQ(::tensorflow::Status::OK(), (statement)) + +// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not +// provide much value (when they fail, they would just print the OK status +// which conveys no more information than EXPECT_FALSE(status.ok()); +// If you want to check for particular errors, better alternatives are: +// EXPECT_EQ(::util::Status(...expected error...), status.StripMessage()); +// EXPECT_THAT(status.ToString(), HasSubstr("expected error")); +// Also, see testing/lib/util/status_util.h. + +#endif // TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_ diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc new file mode 100644 index 0000000000..57c5139f47 --- /dev/null +++ b/tensorflow/core/lib/core/stringpiece.cc @@ -0,0 +1,57 @@ +#include "tensorflow/core/lib/core/stringpiece.h" + +#include +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { + +size_t StringPiece::Hasher::operator()(StringPiece s) const { + return Hash64(s.data(), s.size()); +} + +std::ostream& operator<<(std::ostream& o, StringPiece piece) { + o.write(piece.data(), piece.size()); + return o; +} + +bool StringPiece::contains(StringPiece s) const { + return memmem(data_, size_, s.data_, s.size_) != nullptr; +} + +size_t StringPiece::find(char c, size_t pos) const { + if (pos >= size_) { + return npos; + } + const char* result = + reinterpret_cast(memchr(data_ + pos, c, size_ - pos)); + return result != NULL ? result - data_ : npos; +} + +// Search range is [0..pos] inclusive. If pos == npos, search everything. +size_t StringPiece::rfind(char c, size_t pos) const { + if (size_ == 0) return npos; + for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) { + if (*p == c) { + return p - data_; + } + } + return npos; +} + +bool StringPiece::Consume(StringPiece x) { + if (starts_with(x)) { + remove_prefix(x.size_); + return true; + } + return false; +} + +StringPiece StringPiece::substr(size_t pos, size_t n) const { + if (pos > size_) pos = size_; + if (n > size_ - pos) n = size_ - pos; + return StringPiece(data_ + pos, n); +} + +const StringPiece::size_type StringPiece::npos = size_type(-1); + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h new file mode 100644 index 0000000000..17d4b294e9 --- /dev/null +++ b/tensorflow/core/lib/core/stringpiece.h @@ -0,0 +1,159 @@ +// StringPiece is a simple structure containing a pointer into some external +// storage and a size. The user of a StringPiece must ensure that the slice +// is not used after the corresponding external storage has been +// deallocated. 
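+//
+// A minimal sketch of the lifetime rule above, using only constructors and
+// methods declared in this header:
+//
+//   string s = "hello, world";
+//   StringPiece p(s);             // p refers to s's storage; nothing is copied
+//   string owned = p.ToString();  // safe: copies the bytes out of s
+//   // Once s is destroyed (or its buffer reallocated), p must not be used.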
+// +// Multiple threads can invoke const methods on a StringPiece without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same StringPiece must use +// external synchronization. + +#ifndef TENSORFLOW_LIB_CORE_STRINGPIECE_H_ +#define TENSORFLOW_LIB_CORE_STRINGPIECE_H_ + +#include +#include +#include +#include +#include +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class StringPiece { + public: + typedef size_t size_type; + + // Create an empty slice. + StringPiece() : data_(""), size_(0) {} + + // Create a slice that refers to d[0,n-1]. + StringPiece(const char* d, size_t n) : data_(d), size_(n) {} + + // Create a slice that refers to the contents of "s" + StringPiece(const string& s) : data_(s.data()), size_(s.size()) {} + + // Create a slice that refers to s[0,strlen(s)-1] + StringPiece(const char* s) : data_(s), size_(strlen(s)) {} + + void set(const void* data, size_t len) { + data_ = reinterpret_cast(data); + size_ = len; + } + + // Return a pointer to the beginning of the referenced data + const char* data() const { return data_; } + + // Return the length (in bytes) of the referenced data + size_t size() const { return size_; } + + // Return true iff the length of the referenced data is zero + bool empty() const { return size_ == 0; } + + typedef const char* const_iterator; + typedef const char* iterator; + iterator begin() const { return data_; } + iterator end() const { return data_ + size_; } + + static const size_t npos; + + // Return the ith byte in the referenced data. + // REQUIRES: n < size() + char operator[](size_t n) const { + assert(n < size()); + return data_[n]; + } + + // Change this slice to refer to an empty array + void clear() { + data_ = ""; + size_ = 0; + } + + // Drop the first "n" bytes from this slice. + void remove_prefix(size_t n) { + assert(n <= size()); + data_ += n; + size_ -= n; + } + + void remove_suffix(size_t n) { + assert(size_ >= n); + size_ -= n; + } + + size_t find(char c, size_t pos = 0) const; + size_t rfind(char c, size_t pos = npos) const; + bool contains(StringPiece s) const; + + // Checks whether StringPiece starts with x and if so advances the beginning + // of it to past the match. It's basically a shortcut for starts_with + // followed by remove_prefix. + bool Consume(StringPiece x); + + StringPiece substr(size_t pos, size_t n = npos) const; + + struct Hasher { + size_t operator()(StringPiece arg) const; + }; + + // Return a string that contains the copy of the referenced data. + std::string ToString() const { return std::string(data_, size_); } + + // Three-way comparison. 
Returns value: + // < 0 iff "*this" < "b", + // == 0 iff "*this" == "b", + // > 0 iff "*this" > "b" + int compare(StringPiece b) const; + + // Return true iff "x" is a prefix of "*this" + bool starts_with(StringPiece x) const { + return ((size_ >= x.size_) && (memcmp(data_, x.data_, x.size_) == 0)); + } + // Return true iff "x" is a suffix of "*this" + bool ends_with(StringPiece x) const { + return ((size_ >= x.size_) && + (memcmp(data_ + (size_ - x.size_), x.data_, x.size_) == 0)); + } + + private: + const char* data_; + size_t size_; + + // Intentionally copyable +}; + +inline bool operator==(StringPiece x, StringPiece y) { + return ((x.size() == y.size()) && + (memcmp(x.data(), y.data(), x.size()) == 0)); +} + +inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); } + +inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; } +inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; } +inline bool operator<=(StringPiece x, StringPiece y) { + return x.compare(y) <= 0; +} +inline bool operator>=(StringPiece x, StringPiece y) { + return x.compare(y) >= 0; +} + +inline int StringPiece::compare(StringPiece b) const { + const size_t min_len = (size_ < b.size_) ? size_ : b.size_; + int r = memcmp(data_, b.data_, min_len); + if (r == 0) { + if (size_ < b.size_) + r = -1; + else if (size_ > b.size_) + r = +1; + } + return r; +} + +// allow StringPiece to be logged +extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece); + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_STRINGPIECE_H_ diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc new file mode 100644 index 0000000000..e9b84d3102 --- /dev/null +++ b/tensorflow/core/lib/core/threadpool.cc @@ -0,0 +1,108 @@ +#include "tensorflow/core/lib/core/threadpool.h" + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" + +namespace tensorflow { +namespace thread { + +struct ThreadPool::Waiter { + condition_variable cv; + bool ready; +}; + +ThreadPool::ThreadPool(Env* env, const string& name, int num_threads) + : ThreadPool(env, ThreadOptions(), name, num_threads) {} + +ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads) + : name_(name) { + CHECK_GE(num_threads, 1); + string name_prefix = "tf_" + name_; + for (int i = 0; i < num_threads; i++) { + threads_.push_back(env->StartThread(thread_options, name_prefix, + [this]() { WorkerLoop(); })); + } +} + +ThreadPool::~ThreadPool() { + { + // Wait for all work to get done. + mutex_lock l(mu_); + + // Inform every thread to exit. + for (size_t i = 0; i < threads_.size(); ++i) { + pending_.push_back({nullptr, 0}); + } + + // Wakeup all waiters. + for (auto w : waiters_) { + w->ready = true; + w->cv.notify_one(); + } + } + + // Wait for threads to finish. 
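+ // (Deleting a Thread is expected to block until that thread has exited;
+ // the nullptr items queued above make each WorkerLoop() return, so this
+ // loop finishes only after all workers have drained.)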
+ for (auto t : threads_) { + delete t; + } +} + +bool ThreadPool::HasPendingClosures() const { + mutex_lock l(mu_); + return pending_.size() != 0; +} + +void ThreadPool::Schedule(std::function fn) { + CHECK(fn != nullptr); + uint64 id = 0; + if (port::Tracing::IsActive()) { + id = port::Tracing::UniqueId(); + port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure, + id); + } + + mutex_lock l(mu_); + pending_.push_back({fn, id}); + if (!waiters_.empty()) { + Waiter* w = waiters_.back(); + waiters_.pop_back(); + w->ready = true; + w->cv.notify_one(); + } +} + +void ThreadPool::WorkerLoop() { + port::Tracing::RegisterCurrentThread(name_.c_str()); + mutex_lock l(mu_); + Waiter w; + while (true) { + while (pending_.empty()) { + // Wait for work to be assigned to me + w.ready = false; + waiters_.push_back(&w); + while (!w.ready) { + w.cv.wait(l); + } + } + // Pick up pending work + Item item = pending_.front(); + pending_.pop_front(); + if (item.fn == nullptr) { + break; + } + mu_.unlock(); + if (item.id != 0) { + port::Tracing::ScopedActivity region( + port::Tracing::EventCategory::kRunClosure, item.id); + item.fn(); + } else { + item.fn(); + } + mu_.lock(); + } +} + +} // namespace thread +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h new file mode 100644 index 0000000000..5cf780fa86 --- /dev/null +++ b/tensorflow/core/lib/core/threadpool.h @@ -0,0 +1,59 @@ +#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_ +#define TENSORFLOW_LIB_CORE_THREADPOOL_H_ + +#include +#include +#include +#include +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace thread { + +class ThreadPool { + public: + // Construct a pool that contains "num_threads" threads with specified "name". + // env->StartThread() is used to create individual threads. + // + // REQUIRES: num_threads > 0 + ThreadPool(Env* env, const string& name, int num_threads); + + // Construct a pool that contains "num_threads" threads with specified "name". + // env->StartThread() is used to create individual threads. + // + // REQUIRES: num_threads > 0 + ThreadPool(Env* env, const ThreadOptions& thread_options, const string& name, + int num_threads); + + // Wait until all scheduled work has finished and then destroy the + // set of threads. + virtual ~ThreadPool(); + + // Schedule fn() for execution in the pool of threads. + virtual void Schedule(std::function fn); + + virtual bool HasPendingClosures() const; + + private: + struct Waiter; + struct Item { + std::function fn; + uint64 id; + }; + + void WorkerLoop(); + + const string name_; + mutable mutex mu_; + std::vector threads_; // All threads + std::vector waiters_; // Stack of waiting threads. 
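+ // (Schedule() appends to the pending_ queue below and, if a worker is
+ // parked on waiters_, wakes exactly one; WorkerLoop() parks itself on
+ // waiters_ whenever pending_ is empty. See threadpool.cc.)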
+ std::deque pending_; // Queue of pending work + + TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool); +}; + +} // namespace thread +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_CORE_THREADPOOL_H_ diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc new file mode 100644 index 0000000000..f4909c445c --- /dev/null +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -0,0 +1,93 @@ +#include "tensorflow/core/lib/core/threadpool.h" + +#include + +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/env.h" +#include + +namespace tensorflow { +namespace thread { + +static const int kNumThreads = 30; + +TEST(ThreadPool, Empty) { + for (int num_threads = 1; num_threads < kNumThreads; num_threads++) { + fprintf(stderr, "Testing with %d threads\n", num_threads); + ThreadPool pool(Env::Default(), "test", num_threads); + } +} + +TEST(ThreadPool, DoWork) { + for (int num_threads = 1; num_threads < kNumThreads; num_threads++) { + fprintf(stderr, "Testing with %d threads\n", num_threads); + const int kWorkItems = 15; + bool work[kWorkItems]; + for (int i = 0; i < kWorkItems; i++) { + work[i] = false; + } + { + ThreadPool pool(Env::Default(), "test", num_threads); + for (int i = 0; i < kWorkItems; i++) { + pool.Schedule([&work, i]() { + ASSERT_FALSE(work[i]); + work[i] = true; + }); + } + } + for (int i = 0; i < kWorkItems; i++) { + ASSERT_TRUE(work[i]); + } + } +} + +static void BM_Sequential(int iters) { + ThreadPool pool(Env::Default(), "test", kNumThreads); + // Decrement count sequentially until 0. + int count = iters; + mutex done_lock; + condition_variable done; + bool done_flag = false; + std::function work = [&pool, &count, &done_lock, &done, &done_flag, + &work]() { + if (count--) { + pool.Schedule(work); + } else { + mutex_lock l(done_lock); + done_flag = true; + done.notify_all(); + } + }; + work(); + mutex_lock l(done_lock); + if (!done_flag) { + done.wait(l); + } +} +BENCHMARK(BM_Sequential); + +static void BM_Parallel(int iters) { + ThreadPool pool(Env::Default(), "test", kNumThreads); + // Decrement count concurrently until 0. + std::atomic_int_fast32_t count(iters); + mutex done_lock; + condition_variable done; + bool done_flag = false; + for (int i = 0; i < iters; ++i) { + pool.Schedule([&count, &done_lock, &done, &done_flag]() { + if (count.fetch_sub(1) == 1) { + mutex_lock l(done_lock); + done_flag = true; + done.notify_all(); + } + }); + } + mutex_lock l(done_lock); + if (!done_flag) { + done.wait(l); + } +} +BENCHMARK(BM_Parallel); + +} // namespace thread +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/array_slice.h b/tensorflow/core/lib/gtl/array_slice.h new file mode 100644 index 0000000000..813fb126e3 --- /dev/null +++ b/tensorflow/core/lib/gtl/array_slice.h @@ -0,0 +1,299 @@ +// An ArraySlice represents an immutable array of elements of type +// T. It has a length "length", and a base pointer "ptr", and the +// array it represents contains the elements "ptr[0] .. ptr[len-1]". +// The backing store for the array is *not* owned by the ArraySlice +// object, and clients must arrange for the backing store to remain +// live while the ArraySlice object is in use. +// +// An ArraySlice is somewhat analogous to a StringPiece, but for +// array elements of type T. +// +// Implicit conversion operations are provided from types such as +// std::vector and util::gtl::InlinedVector. 
Note that ArraySlice +// objects constructed from types in this way may be invalidated by +// any operations that mutate the underlying vector. +// +// One common use for ArraySlice is when passing arguments to a +// routine where you want to be able to accept a variety of array +// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array, +// etc.). The usual approach here is to have the client explicitly +// pass in a pointer and a length, as in: +// +// void MyRoutine(const int* elems, int N) { +// for (int i = 0; i < N; i++) { .. do something with elems[i] .. } +// } +// +// Unfortunately, this leads to ugly and error-prone code at the call site: +// +// std::vector my_vector; +// MyRoutine(vector_as_array(&my_vector), my_vector.size()); +// +// util::gtl::InlinedVector my_inline_vector; +// MyRoutine(my_inline_vector.array(), my_inline_vector.size()); +// +// int my_array[10]; +// MyRoutine(my_array, 10); +// +// Instead, you can use an ArraySlice as the argument to the routine: +// +// void MyRoutine(ArraySlice a) { +// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. } +// } +// +// This makes the call sites cleaner, for the most part: +// +// std::vector my_vector; +// MyRoutine(my_vector); +// +// util::gtl::InlinedVector my_inline_vector; +// MyRoutine(my_inline_vector); +// +// int my_array[10]; +// MyRoutine(my_array); +// +// int* my_array = new int[10]; +// MyRoutine(gtl::ArraySlice(my_array, 10)); +// +// MutableArraySlice represents a mutable array of elements, and, like +// ArraySlice, does not own the backing store. The implicit constructors it +// provides allow functions not to worry about whether their mutable arguments +// refer to vectors, arrays, proto2::RepeatedFields, etc.: +// +// void MyMutatingRoutine(MutableArraySlice a) { +// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. 
} +// } +// +// std::vector my_vector; +// MyMutatingRoutine(&my_vector); +// +// int my_array[10]; +// MyMutatingRoutine(my_array); +// +// int* my_array = new int[10]; +// MyMutatingRoutine(gtl::MutableArraySlice(my_array, 10)); +// +// MyProto my_proto; +// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); } +// MyMutatingRoutine(my_proto.mutable_value()); + +#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_ +#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/gtl/array_slice_internal.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace gtl { + +template +class ArraySlice { + private: + typedef array_slice_internal::ArraySliceImpl Impl; + + public: + typedef T value_type; + typedef typename Impl::pointer pointer; + typedef typename Impl::const_pointer const_pointer; + typedef typename Impl::reference reference; + typedef typename Impl::const_reference const_reference; + typedef typename Impl::iterator iterator; + typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::reverse_iterator reverse_iterator; + typedef typename Impl::const_reverse_iterator const_reverse_iterator; + typedef typename Impl::size_type size_type; + typedef typename Impl::difference_type difference_type; + + static const size_type npos = Impl::npos; + + ArraySlice() : impl_(nullptr, 0) {} + ArraySlice(const_pointer array, size_type length) : impl_(array, length) {} + + // Implicit conversion constructors + ArraySlice(const std::vector& v) // NOLINT(runtime/explicit) + : impl_(v.data(), v.size()) {} + + template + ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit) + : impl_(a, N) {} + + template + ArraySlice(const InlinedVector& v) // NOLINT(runtime/explicit) + : impl_(v.array(), v.size()) {} + + // The constructor for any class supplying 'data() const' that returns either + // const T* or a less const-qualified version of it, and 'some_integral_type + // size() const'. proto2::RepeatedField, string and (since C++11) + // std::vector and std::array are examples of this. See + // array_slice_internal.h for details. + template > + ArraySlice(const V& v) // NOLINT(runtime/explicit) + : impl_(v) {} + + // Implicitly constructs an ArraySlice from an initializer list. This makes it + // possible to pass a brace-enclosed initializer list to a function expecting + // an ArraySlice: + // void Process(ArraySlice x); + // Process({1, 2, 3}); + // The data referenced by the initializer_list must outlive this + // ArraySlice. For example, "ArraySlice s={1,2};" and "return + // ArraySlice({3,4});" are errors, as the resulting ArraySlice may + // reference data that is no longer valid. + ArraySlice(std::initializer_list v) // NOLINT(runtime/explicit) + : impl_(v.begin(), v.size()) {} + + // Substring of another ArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + // If len==npos, the substring continues till the end of x. 
+ ArraySlice(const ArraySlice& x, size_type pos, size_type len) + : impl_(x.impl_, pos, len) {} + + const_pointer data() const { return impl_.data(); } + size_type size() const { return impl_.size(); } + size_type length() const { return size(); } + bool empty() const { return size() == 0; } + + void clear() { impl_.clear(); } + + const_reference operator[](size_type i) const { return impl_[i]; } + const_reference at(size_type i) const { return impl_.at(i); } + const_reference front() const { return impl_.front(); } + const_reference back() const { return impl_.back(); } + + const_iterator begin() const { return impl_.begin(); } + const_iterator end() const { return impl_.end(); } + const_reverse_iterator rbegin() const { return impl_.rbegin(); } + const_reverse_iterator rend() const { return impl_.rend(); } + + void remove_prefix(size_type n) { impl_.remove_prefix(n); } + void remove_suffix(size_type n) { impl_.remove_suffix(n); } + void pop_back() { remove_suffix(1); } + void pop_front() { remove_prefix(1); } + + // These relational operators have the same semantics as the + // std::vector relational operators: they do deep (elementwise) + // comparisons. Array slices are equal iff their size is the same + // and all their elements are equal. + bool operator==(ArraySlice other) const { return impl_ == other.impl_; } + bool operator!=(ArraySlice other) const { return impl_ != other.impl_; } + + private: + Impl impl_; +}; + +// Mutable version of ArraySlice, which allows the clients to mutate the +// underlying data. It is implicitly convertible to ArraySlice since it provides +// the data() and size() methods with correct signatures. When a +// MutableArraySlice is created from a pointer to a container (as opposed to raw +// memory pointer), the pointer must not be null. +// +// A note on const-ness: "mutable" here refers to the mutability of the +// underlying data, not of the slice itself. It is perfectly reasonable to have +// a variable of type "const MutableArraySlice"; this means that the bounds +// of the view on the array cannot be changed, but the underlying data in the +// array still may be modified. This is akin to a "T* const" pointer, as opposed +// to a "const T*" pointer (corresponding to a non-const ArraySlice). 
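+//
+// A minimal sketch of the const-ness note above, using only members declared
+// in this class:
+//
+//   int buf[3] = {1, 2, 3};
+//   const gtl::MutableArraySlice<int> s(buf, 3);
+//   s[0] = 7;               // OK: the underlying data stays mutable
+//   // s.remove_prefix(1);  // error: the bounds of a const slice are fixed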
+template +class MutableArraySlice { + private: + typedef array_slice_internal::MutableArraySliceImpl Impl; + + public: + typedef T value_type; + typedef typename Impl::pointer pointer; + typedef typename Impl::const_pointer const_pointer; + typedef typename Impl::reference reference; + typedef typename Impl::const_reference const_reference; + typedef typename Impl::iterator iterator; + typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::reverse_iterator reverse_iterator; + typedef typename Impl::const_reverse_iterator const_reverse_iterator; + typedef typename Impl::size_type size_type; + typedef typename Impl::difference_type difference_type; + + static const size_type npos = Impl::npos; + + MutableArraySlice() : impl_(nullptr, 0) {} + MutableArraySlice(pointer array, size_type length) : impl_(array, length) {} + + // Implicit conversion constructors + MutableArraySlice(std::vector* v) // NOLINT(runtime/explicit) + : impl_(v->data(), v->size()) {} + + template + MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit) + : impl_(a, N) {} + + template + MutableArraySlice( + InlinedVector* v) // NOLINT(runtime/explicit) + : impl_(v->mutable_array(), v->size()) {} + + // The constructor for any class supplying 'T* data()' or 'T* mutable_data()' + // (the former is called if both exist), and 'some_integral_type size() + // const'. proto2::RepeatedField is an example of this. Also supports string + // arguments, when T==char. The appropriate ctor is selected using SFINAE. See + // array_slice_internal.h for details. + template > + MutableArraySlice(V* v) // NOLINT(runtime/explicit) + : impl_(v) {} + + // Substring of another MutableArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + // If len==npos, the substring continues till the end of x. + MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len) + : impl_(x.impl_, pos, len) {} + + // Accessors. + pointer data() const { return impl_.data(); } + size_type size() const { return impl_.size(); } + size_type length() const { return size(); } + bool empty() const { return size() == 0; } + + void clear() { impl_.clear(); } + + reference operator[](size_type i) const { return impl_[i]; } + reference at(size_type i) const { return impl_.at(i); } + reference front() const { return impl_.front(); } + reference back() const { return impl_.back(); } + + iterator begin() const { return impl_.begin(); } + iterator end() const { return impl_.end(); } + reverse_iterator rbegin() const { return impl_.rbegin(); } + reverse_iterator rend() const { return impl_.rend(); } + + void remove_prefix(size_type n) { impl_.remove_prefix(n); } + void remove_suffix(size_type n) { impl_.remove_suffix(n); } + void pop_back() { remove_suffix(1); } + void pop_front() { remove_prefix(1); } + + bool operator==(ArraySlice other) const { + return ArraySlice(*this) == other; + } + bool operator!=(ArraySlice other) const { + return ArraySlice(*this) != other; + } + + // DEPRECATED(jacobsa): Please use data() instead. 
+ pointer mutable_data() const { return impl_.data(); } + + private: + Impl impl_; +}; + +template +const typename ArraySlice::size_type ArraySlice::npos; +template +const typename MutableArraySlice::size_type MutableArraySlice::npos; + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_ diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h new file mode 100644 index 0000000000..080f0a38d8 --- /dev/null +++ b/tensorflow/core/lib/gtl/array_slice_internal.h @@ -0,0 +1,253 @@ +// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by +// array_slice.h. + +// Helper functions and templates for ArraySlice. + +#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ +#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace gtl { +namespace array_slice_internal { + +// Template logic for generic constructors. + +// Wrappers whose Get() delegates to the appropriate method of a container, and +// is defined when this method exists. Delegates to the const method if C is a +// const type. +struct Data { + template + static decltype(std::declval().data()) Get(C* v) { + return v->data(); + } +}; + +struct MutableData { + template + static decltype(std::declval().mutable_data()) Get(C* v) { + return v->mutable_data(); + } +}; + +struct Size { + template + static decltype(std::declval().size()) Get(C* v) { + return v->size(); + } +}; + +struct MutableStringData { + // Defined only for string. + static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); } +}; + +// Checks whether M::Get(C*) is defined and has a return type R such that +// Checker::valid()==true. +template +struct HasGetHelper : public M { + private: + struct None {}; + // M::Get is selected when it is viable. Get(...) is selected otherwise. + using M::Get; + static None Get(...); + + public: + static constexpr bool HasGet() { + using Result = decltype(Get(std::declval())); + return !std::is_same() && Checker::template valid(); + } +}; + +// Defines HasGet() for a particular method, container, and checker. If +// HasGet()==true, provides Get() that delegates to the method. +template ::HasGet()> +struct Wrapper { + static constexpr bool HasGet() { return false; } +}; + +template +struct Wrapper { + static constexpr bool HasGet() { return true; } + static decltype(M::Get(std::declval())) Get(C* v) { return M::Get(v); } +}; + +// Type checker for a method returning an integral value. +struct SizeChecker { + template + static constexpr bool valid() { + return std::is_integral::value; + } +}; + +// Type checker for a method returning either a pointer to T or a less const +// version of that. +template +struct DataChecker { + // We want to enable conversion from std::vector to ArraySlice + // but + // disable conversion from std::vector to ArraySlice. Here we + // use + // the fact that U** is convertible to Q* const* if and only if Q is the same + // type or a more cv-qualified version of U. + template + static constexpr bool valid() { + return std::is_convertible::value; + } +}; + +// Aliases to A if A::HasGet()==true, or to B otherwise. +template +using FirstWithGet = typename std::conditional::type; + +// Wraps C::data() const, returning a pointer to const data. +template +using ContainerData = Wrapper, const C>; + +// Wraps a method returning a pointer to mutable data. 
Prefers data() over +// mutable_data(), and handles strings when T==char. If data() returns a pointer +// to mutable data, it is most likely overloaded, but may also be a single +// method 'T* C::data() const' in a non-STL-compliant container. +template +using ContainerMutableData = + FirstWithGet, C>, + FirstWithGet, C>, + Wrapper, C>>>; + +// Wraps C::size() const. +template +using ContainerSize = Wrapper; + +// Implementation class for ArraySlice and MutableArraySlice. In the case of +// ArraySlice, T will be a const type; for MutableArraySlice, T will be a +// mutable type. +template +class ArraySliceImplBase { + public: + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef pointer iterator; + typedef const_pointer const_iterator; + typedef std::reverse_iterator reverse_iterator; + typedef std::reverse_iterator const_reverse_iterator; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + static const size_type npos = -1; + + ArraySliceImplBase(pointer array, size_type length) + : ptr_(array), length_(length) {} + + // Substring of another ArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len) + : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {} + + // Some of the const methods below return pointers and references to mutable + // data. This is only the case in this internal class; ArraySlice and + // MutableArraySlice provide deep-constness. + + pointer data() const { return ptr_; } + size_type size() const { return length_; } + + void clear() { + ptr_ = nullptr; + length_ = 0; + } + + reference operator[](size_type i) const { return ptr_[i]; } + reference at(size_type i) const { + DCHECK_LT(i, length_); + return ptr_[i]; + } + reference front() const { + DCHECK_GT(length_, 0); + return ptr_[0]; + } + reference back() const { + DCHECK_GT(length_, 0); + return ptr_[length_ - 1]; + } + + void remove_prefix(size_type n) { + DCHECK_GE(length_, n); + ptr_ += n; + length_ -= n; + } + void remove_suffix(size_type n) { + DCHECK_GE(length_, n); + length_ -= n; + } + + iterator begin() const { return ptr_; } + iterator end() const { return ptr_ + length_; } + reverse_iterator rbegin() const { return reverse_iterator(end()); } + reverse_iterator rend() const { return reverse_iterator(begin()); } + + bool operator==(const ArraySliceImplBase& other) const { + if (size() != other.size()) return false; + if (data() == other.data()) return true; + return std::equal(data(), data() + size(), other.data()); + } + bool operator!=(const ArraySliceImplBase& other) const { + return !(*this == other); + } + + private: + pointer ptr_; + size_type length_; +}; + +template +class ArraySliceImpl : public ArraySliceImplBase { + public: + using ArraySliceImplBase::ArraySliceImplBase; + + // Defined iff the data and size accessors for the container C have been + // defined. + template + using EnableIfConvertibleFrom = + typename std::enable_if::HasGet() && + ContainerSize::HasGet()>::type; + + // Constructs from a container when EnableIfConvertibleFrom is + // defined. std::addressof handles types with overloaded operator&. 
+ template + explicit ArraySliceImpl(const C& v) + : ArraySliceImplBase(ContainerData::Get(std::addressof(v)), + ContainerSize::Get(std::addressof(v))) {} +}; + +template +class MutableArraySliceImpl : public ArraySliceImplBase { + public: + using ArraySliceImplBase::ArraySliceImplBase; + + template + using EnableIfConvertibleFrom = + typename std::enable_if::HasGet() && + ContainerSize::HasGet()>::type; + + template + explicit MutableArraySliceImpl(C* v) + : ArraySliceImplBase(ContainerMutableData::Get(v), + ContainerSize::Get(v)) {} +}; + +} // namespace array_slice_internal +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ diff --git a/tensorflow/core/lib/gtl/array_slice_test.cc b/tensorflow/core/lib/gtl/array_slice_test.cc new file mode 100644 index 0000000000..33ee8fc8dd --- /dev/null +++ b/tensorflow/core/lib/gtl/array_slice_test.cc @@ -0,0 +1,646 @@ +#include "tensorflow/core/lib/gtl/array_slice.h" + +#include +#include +#include +#include + +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/port.h" +#include + +namespace tensorflow { +namespace gtl { +namespace { + +typedef ArraySlice IntSlice; +typedef ArraySlice CharSlice; +typedef MutableArraySlice MutableIntSlice; +typedef MutableArraySlice MutableCharSlice; +typedef std::vector IntVec; + +// Append 0..len-1 to *v +template +static void Fill(Vector* v, int len, int offset = 0) { + for (int i = 0; i < len; i++) { + v->push_back(i + offset); + } +} + +static void TestHelper(const IntSlice& vorig, const IntVec& vec) { + IntSlice other; // To test the assignment return value. + IntSlice v = other = vorig; + const int len = vec.size(); + EXPECT_EQ(v.size(), vec.size()); + + for (int i = 0; i < len; i++) { + EXPECT_EQ(v[i], vec[i]); + EXPECT_EQ(v.at(i), vec[i]); + } + EXPECT_EQ(v.begin(), gtl::vector_as_array(&vec)); + + int counter = 0; + for (IntSlice::iterator it = v.begin(); it != v.end(); ++it) { + EXPECT_EQ(counter, *it); + counter++; + } + EXPECT_EQ(counter, len); + + counter = 0; + for (IntSlice::const_iterator it = v.begin(); it != v.end(); ++it) { + EXPECT_EQ(counter, *it); + counter++; + } + EXPECT_EQ(counter, len); + + if (len > 0) { + EXPECT_EQ(0, v.front()); + EXPECT_EQ(len - 1, v.back()); + v.pop_back(); + EXPECT_EQ(len - 1, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(i, v[i]); + } + if (len > 1) { + v.pop_front(); + EXPECT_EQ(len - 2, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(i + 1, v[i]); + } + } + } +} + +// The element access test that is applicable both when MutableArraySlice is +// const and when it's not. 
+template +void MutableTestHelperTemplated(V v, int* ptr, const int len) { + CHECK_EQ(v.size(), len); + + for (int i = 0; i < len; i++) { + EXPECT_EQ(ptr + i, &v[i]); + EXPECT_EQ(ptr + i, &v.at(i)); + } + EXPECT_EQ(ptr, v.begin()); + EXPECT_EQ(ptr + len, v.end()); + EXPECT_EQ(ptr, v.data()); + + int counter = 0; + for (MutableIntSlice::const_iterator it = v.begin(); it != v.end(); ++it) { + EXPECT_EQ(ptr + counter, &*it); + counter++; + } + EXPECT_EQ(counter, len); + + EXPECT_EQ(len, std::distance(v.rbegin(), v.rend())); + + if (len > 0) { + EXPECT_EQ(ptr, &v.front()); + EXPECT_EQ(ptr + len - 1, &v.back()); + EXPECT_EQ(ptr + len - 1, &*v.rbegin()); + EXPECT_EQ(ptr, &*(v.rend() - 1)); + } +} + +static void MutableTestHelper(const MutableIntSlice& vorig, int* ptr, + const int len) { + // Test the data accessors both when the MutableArraySlice is declared const, + // and when it is not. + MutableTestHelperTemplated(vorig, ptr, len); + MutableTestHelperTemplated(vorig, ptr, len); + + MutableIntSlice other; // To test the assignment return value. + MutableIntSlice v = other = vorig; + EXPECT_EQ(ptr, v.mutable_data()); + + int counter = 0; + for (MutableIntSlice::iterator it = v.begin(); it != v.end(); ++it) { + EXPECT_EQ(ptr + counter, &*it); + counter++; + } + EXPECT_EQ(counter, len); + + if (len > 0) { + // Test that elements are assignable. + v[0] = 1; + v.front() = 2; + v.back() = 5; + *v.mutable_data() = 4; + std::fill(v.begin(), v.end(), 5); + std::fill(v.rbegin(), v.rend(), 6); + // Test size-changing methods. + v.pop_back(); + EXPECT_EQ(len - 1, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(ptr + i, &v[i]); + } + if (len > 1) { + v.pop_front(); + EXPECT_EQ(len - 2, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(ptr + i + 1, &v[i]); + } + } + } +} + +template +static void TestImplicitConversion(const IntSlice& v, const Vector& vec) { + EXPECT_EQ(v.size(), vec.size()); + for (size_t i = 0; i < v.size(); i++) { + EXPECT_EQ(v[i], vec[i]); + } +} + +template +static void TestImplicitConversion(const CharSlice& v, const Vector& vec) { + TestImplicitConversion(IntVec(v.begin(), v.end()), vec); +} + +static void TestImplicitConversion(const MutableIntSlice& v, const int* data, + int size) { + EXPECT_EQ(size, v.size()); + for (size_t i = 0; i < v.size(); i++) { + EXPECT_EQ(data + i, &v[i]); + } +} + +static void TestImplicitConversion(const MutableCharSlice& v, const char* data, + int size) { + EXPECT_EQ(size, v.size()); + for (size_t i = 0; i < v.size(); i++) { + EXPECT_EQ(data + i, &v[i]); + } +} +// A struct supplying the data(), mutable_data() and size() methods, just like +// e.g. proto2::RepeatedField. +struct RepeatedField { + std::vector storage; + const int* data() const { return storage.data(); } + int* mutable_data() { return storage.data(); } + int size() const { return storage.size(); } +}; + +// A struct supplying the data() (both mutable and const versions) and +// size(). It also supplies mutable_data() but we test that data() is selected +// instead. +struct ContainerWithOverloads { + std::vector storage; + std::vector wrong_storage; + const int* data() const { return storage.data(); } + int* data() { return storage.data(); } + // MutableArraySlice should not call mutable_data(), preferring data() + // instead. + int* mutable_data() { return wrong_storage.data(); } + int size() const { return storage.size(); } +}; + +// A struct supplying data() and size() methods. 
+struct ContainerWithShallowConstData { + std::vector storage; + int* data() const { return const_cast(storage.data()); } + int size() const { return storage.size(); } +}; + +TEST(IntSlice, Simple) { + for (int len = 0; len < 20; len++) { + IntVec vec; + Fill(&vec, len); + TestHelper(IntSlice(vec), vec); + TestHelper(IntSlice(vec.data(), vec.size()), vec); + } +} + +TEST(IntSlice, WithPosAndLen) { + IntVec vec; + Fill(&vec, 20); + for (size_t len = 0; len < vec.size(); len++) { + IntVec subvec(vec.begin(), vec.begin() + len); + TestImplicitConversion(IntSlice(vec, 0, len), subvec); + TestImplicitConversion(IntSlice(IntSlice(vec), 0, len), subvec); + } + EXPECT_EQ(0, IntSlice(vec, 0, 0).size()); + EXPECT_EQ(0, IntSlice(IntSlice(vec), 0, 0).size()); + TestImplicitConversion(IntSlice(vec, 0, IntSlice::npos), vec); +} + +TEST(IntSlice, Clear) { + for (int len = 0; len < 20; len++) { + IntVec vec; + Fill(&vec, len); + IntSlice v(vec); + v.clear(); + EXPECT_EQ(0, v.size()); + EXPECT_EQ(v.begin(), v.end()); + } +} + +TEST(IntSlice, Swap) { + for (int l1 = 0; l1 < 20; l1++) { + for (int l2 = 0; l2 < 20; l2++) { + IntVec avec, bvec; + Fill(&avec, l1); + Fill(&bvec, l2, 100); + IntSlice a(avec), b(bvec); + using std::swap; + swap(a, b); + EXPECT_EQ(l1, b.size()); + EXPECT_EQ(l2, a.size()); + for (int i = 0; i < l1; i++) { + EXPECT_EQ(i, b[i]); + } + for (int i = 0; i < l2; i++) { + EXPECT_EQ(100 + i, a[i]); + } + } + } +} + +TEST(IntSlice, ImplicitConversion) { + for (int len = 0; len < 20; len++) { + IntVec vec; + Fill(&vec, len); + IntSlice slice; + slice = vec; + TestImplicitConversion(vec, vec); + TestImplicitConversion(slice, vec); + TestImplicitConversion(IntSlice(vec.data(), vec.size()), vec); + } +} + +TEST(IntSlice, InlinedVectorConversion) { + for (int len = 0; len < 20; len++) { + InlinedVector inline_vec; + for (int i = 0; i < len; i++) { + inline_vec.push_back(i); + } + IntVec vec; + Fill(&vec, len); + IntSlice v = inline_vec; // Test assignment + static_cast(v); + TestImplicitConversion(inline_vec, vec); + } +} + +TEST(IntSlice, StaticArrayConversion) { + int array[20]; + IntVec vec; + Fill(&vec, TF_ARRAYSIZE(array)); + std::copy(vec.begin(), vec.end(), array); + IntSlice v = array; // Test assignment + static_cast(v); + TestImplicitConversion(array, vec); +} + +TEST(IntSlice, StdArrayConversion) { + std::array array; + IntVec vec; + Fill(&vec, array.size()); + std::copy(vec.begin(), vec.end(), array.begin()); + + // Check assignment. + { + IntSlice v = array; + static_cast(v); + } + + // Check sub-slice initialization. + { + IntSlice v = {array, 10, 15}; + static_cast(v); + } + + TestImplicitConversion(array, vec); +} + +// Values according to the Fill function. 
+static const int test_const_array[] = {0, 1, 2}; + +TEST(IntSlice, ConstStaticArrayConversion) { + IntVec vec; + Fill(&vec, TF_ARRAYSIZE(test_const_array)); + IntSlice v = test_const_array; // Test assignment + static_cast(v); + TestImplicitConversion(test_const_array, vec); +} + +TEST(IntSlice, RepeatedFieldConversion) { + RepeatedField repeated_field; + IntVec vec; + Fill(&vec, 20); + repeated_field.storage = vec; + IntSlice v = repeated_field; // Test assignment + static_cast(v); + TestImplicitConversion(repeated_field, vec); +} + +TEST(IntSlice, ContainerWithOverloadsConversion) { + ContainerWithOverloads container; + Fill(&container.storage, 20); + container.wrong_storage.resize(container.size()); + IntSlice v = container; // Test assignment + static_cast(v); + TestImplicitConversion(container, container.storage); +} + +TEST(IntSlice, ContainerWithShallowConstDataConversion) { + ContainerWithShallowConstData container; + Fill(&container.storage, 20); + IntSlice v = container; // Test assignment + static_cast(v); + TestImplicitConversion(container, container.storage); +} + +TEST(IntSlice, MutableIntSliceConversion) { + IntVec vec(20); + IntSlice slice = MutableIntSlice(&vec); + EXPECT_EQ(vec.size(), slice.size()); + EXPECT_EQ(vec.data(), slice.data()); +} + +TEST(IntSlice, Equality) { + IntVec vec1(20); + IntVec vec2(20); + // These two slices are from different vectors, but have the same + // size and have the same elements (right now). They should + // compare equal. + const IntSlice from1(vec1); + const IntSlice from2(vec2); + EXPECT_EQ(from1, from1); + EXPECT_EQ(from1, from2); + + // This verifies that MutableArraySlices can be compared freely with + // ArraySlices. + const MutableIntSlice mutable_from1(&vec1); + const MutableIntSlice mutable_from2(&vec2); + EXPECT_EQ(from1, mutable_from1); + EXPECT_EQ(mutable_from1, from1); + EXPECT_EQ(mutable_from1, mutable_from2); + EXPECT_EQ(mutable_from2, mutable_from1); + + // With a different size, the array slices should not be equal. + EXPECT_NE(from1, IntSlice(from1, 0, from1.size() - 1)); + + // With different contents, the array slices should not be equal. + ++vec2.back(); + EXPECT_NE(from1, from2); +} + +// Compile-asserts that the argument has the expected type. 
+template +void CheckType(const T& value) { + testing::StaticAssertTypeEq(); +} + +TEST(IntSlice, ExposesContainerTypesAndConsts) { + IntSlice slice; + const IntSlice const_slice; + CheckType(slice.begin()); + CheckType(const_slice.end()); + CheckType(const_slice.rbegin()); + CheckType(slice.rend()); + testing::StaticAssertTypeEq(); + testing::StaticAssertTypeEq(); + testing::StaticAssertTypeEq(); + EXPECT_EQ(static_cast(-1), IntSlice::npos); +} + +void TestEmpty(IntSlice slice) { ASSERT_TRUE(slice.empty()); } + +void TestRange(IntSlice slice, int from, int to) { + ASSERT_EQ(to - from + 1, slice.size()); + for (size_t i = 0; i < slice.size(); ++i) { + EXPECT_EQ(from + i, slice[i]); + } +} + +TEST(IntSlice, InitializerListConversion) { + TestEmpty({}); + TestRange({1}, 1, 1); + TestRange({10, 11, 12, 13}, 10, 13); +} + +TEST(CharSlice, StringConversion) { + IntVec vec; + Fill(&vec, 20); + string str(vec.begin(), vec.end()); + CharSlice v = str; // Test assignment + static_cast(v); + TestImplicitConversion(str, vec); +} + +TEST(IntPtrSlice, ConstConversion) { + int one = 1; + int two = 2; + std::vector vec; + vec.push_back(&one); + vec.push_back(&two); + ArraySlice v = vec; + ASSERT_EQ(2, v.size()); + EXPECT_EQ(&one, v[0]); + EXPECT_EQ(&two, v[1]); +} + +TEST(MutableIntSlice, Simple) { + for (int len = 0; len < 20; len++) { + IntVec vec(len); + MutableTestHelper(MutableIntSlice(&vec), vec.data(), len); + MutableTestHelper(MutableIntSlice(vec.data(), vec.size()), vec.data(), len); + } +} + +TEST(MutableIntSlice, WithPosAndLen) { + IntVec vec(20); + for (size_t len = 0; len < vec.size(); len++) { + TestImplicitConversion(MutableIntSlice(&vec, 0, len), vec.data(), len); + TestImplicitConversion(MutableIntSlice(MutableIntSlice(&vec), 0, len), + vec.data(), len); + } + EXPECT_EQ(0, MutableIntSlice(&vec, 0, 0).size()); + EXPECT_EQ(0, MutableIntSlice(MutableIntSlice(&vec), 0, 0).size()); + TestImplicitConversion(MutableIntSlice(&vec, 0, MutableIntSlice::npos), + vec.data(), vec.size()); +} + +TEST(MutableIntSlice, Clear) { + for (int len = 0; len < 20; len++) { + IntVec vec(len); + MutableIntSlice v(&vec); + v.clear(); + EXPECT_EQ(0, v.size()); + EXPECT_EQ(v.begin(), v.end()); + } +} + +TEST(MutableIntSlice, Swap) { + for (int l1 = 0; l1 < 20; l1++) { + for (int l2 = 0; l2 < 20; l2++) { + IntVec avec(l1), bvec(l2); + MutableIntSlice a(&avec), b(&bvec); + using std::swap; + swap(a, b); + EXPECT_EQ(l1, b.size()); + EXPECT_EQ(l2, a.size()); + for (int i = 0; i < l1; i++) { + EXPECT_EQ(&avec[i], &b[i]); + } + for (int i = 0; i < l2; i++) { + EXPECT_EQ(&bvec[i], &a[i]); + } + } + } +} + +TEST(MutableIntSlice, ImplicitConversion) { + for (int len = 0; len < 20; len++) { + IntVec vec(len); + MutableIntSlice slice; + slice = &vec; + TestImplicitConversion(&vec, vec.data(), len); + TestImplicitConversion(slice, vec.data(), len); + TestImplicitConversion(MutableIntSlice(vec.data(), vec.size()), vec.data(), + len); + } +} + +TEST(MutableIntSlice, InlinedVectorConversion) { + for (int len = 0; len < 20; len++) { + InlinedVector inline_vec; + for (int i = 0; i < len; i++) { + inline_vec.push_back(i); + } + MutableIntSlice v = &inline_vec; // Test assignment + static_cast(v); + TestImplicitConversion(&inline_vec, inline_vec.array(), inline_vec.size()); + } +} + +TEST(MutableIntSlice, StaticArrayConversion) { + int array[20]; + MutableIntSlice v = array; // Test assignment + static_cast(v); + TestImplicitConversion(array, array, TF_ARRAYSIZE(array)); +} + +TEST(MutableIntSlice, StdArrayConversion) { + std::array 
array; + + // Check assignment. + { + MutableIntSlice v = &array; + static_cast(v); + } + + // Check sub-slice initialization. + { + MutableIntSlice v = {&array, 10, 15}; + static_cast(v); + } + + TestImplicitConversion(&array, &array[0], array.size()); +} + +TEST(MutableIntSlice, RepeatedFieldConversion) { + RepeatedField repeated_field; + Fill(&repeated_field.storage, 20); + MutableIntSlice v = &repeated_field; // Test assignment + static_cast(v); + TestImplicitConversion(&repeated_field, repeated_field.storage.data(), + repeated_field.storage.size()); +} + +TEST(MutableIntSlice, ContainerWithOverloadsConversion) { + ContainerWithOverloads container; + Fill(&container.storage, 20); + container.wrong_storage.resize(container.size()); + MutableIntSlice v = &container; // Test assignment + static_cast(v); + TestImplicitConversion(&container, container.storage.data(), + container.storage.size()); +} + +TEST(MutableIntSlice, ContainerWithShallowConstDataConversion) { + ContainerWithShallowConstData container; + Fill(&container.storage, 20); + MutableIntSlice v = &container; // Test assignment + static_cast(v); + TestImplicitConversion(&container, container.storage.data(), + container.storage.size()); +} + +TEST(MutableIntSlice, TypedefsAndConstants) { + testing::StaticAssertTypeEq(); + testing::StaticAssertTypeEq(); + testing::StaticAssertTypeEq(); + testing::StaticAssertTypeEq(); + testing::StaticAssertTypeEq(); + + EXPECT_EQ(static_cast(-1), MutableIntSlice::npos); +} + +TEST(MutableIntSlice, IteratorsAndReferences) { + auto accept_pointer = [](int* x) {}; + auto accept_reference = [](int& x) {}; + auto accept_iterator = [](MutableIntSlice::iterator x) {}; + auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {}; + + int a[1]; + MutableIntSlice s = a; + + accept_pointer(s.data()); + accept_pointer(s.mutable_data()); + accept_iterator(s.begin()); + accept_iterator(s.end()); + accept_reverse_iterator(s.rbegin()); + accept_reverse_iterator(s.rend()); + + accept_reference(s[0]); + accept_reference(s.at(0)); + accept_reference(s.front()); + accept_reference(s.back()); +} + +TEST(MutableIntSlice, IteratorsAndReferences_Const) { + auto accept_pointer = [](int* x) {}; + auto accept_reference = [](int& x) {}; + auto accept_iterator = [](MutableIntSlice::iterator x) {}; + auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {}; + + int a[1]; + const MutableIntSlice s = a; + + accept_pointer(s.data()); + accept_pointer(s.mutable_data()); + accept_iterator(s.begin()); + accept_iterator(s.end()); + accept_reverse_iterator(s.rbegin()); + accept_reverse_iterator(s.rend()); + + accept_reference(s[0]); + accept_reference(s.at(0)); + accept_reference(s.front()); + accept_reference(s.back()); +} + +bool TestMutableOverload(MutableIntSlice slice) { return false; } + +bool TestMutableOverload(MutableCharSlice slice) { return true; } + +TEST(MutableCharSlice, StringConversion) { + for (int len = 0; len < 20; len++) { + string str(len, '\0'); + MutableCharSlice v = &str; // Test assignment + static_cast(v); + TestImplicitConversion(v, str.data(), str.size()); + } + // Verify that only the correct overload is feasible. Note that this would + // fail if the string ctor was declared simply as MutableArraySlice(string*), + // since in that case both overloads would be feasible. 
+ string str; + EXPECT_TRUE(TestMutableOverload(&str)); +} + +} // namespace +} // namespace gtl +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h new file mode 100644 index 0000000000..82b6c2299f --- /dev/null +++ b/tensorflow/core/lib/gtl/edit_distance.h @@ -0,0 +1,82 @@ +#ifndef TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_ +#define TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_ + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace gtl { + +// Calculate the Levenshtein Edit Distance between two contiguous +// sequences, s and t, of type T. +// +// The Levenshtein distance is a symmetric distance defined as the +// smallest number of insertions, deletions, and substitutions +// required to convert sequence s to t (and vice versa). +// Note, this distance does not consider transpositions. +// +// For more details and a reference implementation, see: +// https://en.wikipedia.org/wiki/Levenshtein_distance +// +// This implementation has time complexity O(|s|*|t|) +// and space complexity O(min(|s|, |t|)), where +// |x| := x.size() +// +// A simple call to LevenshteinDistance looks like: +// +// int64 dist = LevenshteinDistance("hi", "bye", std::equal_to()); +// +template +inline int64 LevenshteinDistance(const gtl::ArraySlice& s, + const gtl::ArraySlice& t, const Cmp& cmp) { + const int64 s_size = s.size(); + const int64 t_size = t.size(); + + if (s_size == 0) return t_size; + if (t_size == 0) return s_size; + if (s == t) return 0; + if (t_size > s_size) return LevenshteinDistance(t, s, cmp); + + // Create work vectors + gtl::InlinedVector scratch0(t_size + 1); + gtl::InlinedVector scratch1(t_size + 1); + + int64* previous = scratch0.data(); + int64* current = scratch1.data(); + + // Initialize previous row of distances + std::iota(scratch0.begin(), scratch0.end(), 0); + + for (int64 i = 0; i < s_size; ++i) { + // Swap current and previous rows for next iteration + std::swap(previous, current); + + // Calculate current row distances from previous row + current[0] = i + 1; + + // Fill in the rest of the row + for (int64 j = 0; j < t_size; ++j) { + const int64 cost = cmp(s[i], t[j]) ? 
0 : 1; + current[j + 1] = + std::min(current[j] + 1, // deletion cost + std::min(previous[j + 1] + 1, // insertion cost + previous[j] + cost)); // substitution cost + } + } + + return current[t_size]; +} + +template +inline int64 LevenshteinDistance(const Container1& s, const Container2& t, + const Cmp& cmp) { + return LevenshteinDistance( + gtl::ArraySlice(s.data(), s.size()), + gtl::ArraySlice(t.data(), t.size()), + cmp); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_ diff --git a/tensorflow/core/lib/gtl/edit_distance_test.cc b/tensorflow/core/lib/gtl/edit_distance_test.cc new file mode 100644 index 0000000000..0526ee0a05 --- /dev/null +++ b/tensorflow/core/lib/gtl/edit_distance_test.cc @@ -0,0 +1,125 @@ +#include "tensorflow/core/lib/gtl/edit_distance.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { +namespace gtl { +namespace { + +class LevenshteinDistanceTest : public ::testing::Test { + protected: + std::vector empty_; + std::string s1_; + std::string s1234_; + std::string s567_; + std::string kilo_; + std::string kilogram_; + std::string mother_; + std::string grandmother_; + std::string lower_; + std::string upper_; + + void SetUp() override { + s1_ = "1"; + s1234_ = "1234"; + s567_ = "567"; + kilo_ = "kilo"; + kilogram_ = "kilogram"; + mother_ = "mother"; + grandmother_ = "grandmother"; + lower_ = "lower case"; + upper_ = "UPPER case"; + } +}; + +TEST_F(LevenshteinDistanceTest, BothEmpty) { + ASSERT_EQ(LevenshteinDistance(empty_, empty_, std::equal_to()), 0); +} + +TEST_F(LevenshteinDistanceTest, OneEmpty) { + ASSERT_EQ(LevenshteinDistance(s1234_, empty_, std::equal_to()), 4); + ASSERT_EQ(LevenshteinDistance(empty_, s567_, std::equal_to()), 3); +} + +TEST_F(LevenshteinDistanceTest, SingleElement) { + ASSERT_EQ(LevenshteinDistance(s1234_, s1_, std::equal_to()), 3); + ASSERT_EQ(LevenshteinDistance(s1_, s1234_, std::equal_to()), 3); +} + +TEST_F(LevenshteinDistanceTest, Prefix) { + ASSERT_EQ(LevenshteinDistance(kilo_, kilogram_, std::equal_to()), 4); + ASSERT_EQ(LevenshteinDistance(kilogram_, kilo_, std::equal_to()), 4); +} + +TEST_F(LevenshteinDistanceTest, Suffix) { + ASSERT_EQ(LevenshteinDistance(mother_, grandmother_, std::equal_to()), + 5); + ASSERT_EQ(LevenshteinDistance(grandmother_, mother_, std::equal_to()), + 5); +} + +TEST_F(LevenshteinDistanceTest, DifferentComparisons) { + ASSERT_EQ(LevenshteinDistance(lower_, upper_, std::equal_to()), 5); + ASSERT_EQ(LevenshteinDistance(upper_, lower_, std::equal_to()), 5); + ASSERT_EQ( + LevenshteinDistance(gtl::ArraySlice(lower_.data(), lower_.size()), + gtl::ArraySlice(upper_.data(), upper_.size()), + std::equal_to()), + 5); + auto no_case_cmp = [](char c1, char c2) { + return std::tolower(c1) == std::tolower(c2); + }; + ASSERT_EQ(LevenshteinDistance(lower_, upper_, no_case_cmp), 3); + ASSERT_EQ(LevenshteinDistance(upper_, lower_, no_case_cmp), 3); +} + +TEST_F(LevenshteinDistanceTest, Vectors) { + ASSERT_EQ( + LevenshteinDistance(std::string("algorithm"), std::string("altruistic"), + std::equal_to()), + 6); +} + +static void BM_EditDistanceHelper(int n, int len, bool completely_different) { + string a = + "The quick brown fox jumped over the lazy dog and on and on and on" + " Every good boy deserves fudge. 
In fact, this is a very long sentence " + " w/many bytes.."; + while (a.size() < static_cast(len)) { + a = a + a; + } + string b = a; + if (completely_different) { + for (size_t i = 0; i < b.size(); i++) { + b[i]++; + } + } + while (n-- > 0) { + LevenshteinDistance(gtl::ArraySlice(a.data(), len), + gtl::ArraySlice(b.data(), len), + std::equal_to()); + } +} + +static void BM_EditDistanceSame(int n, int len) { + BM_EditDistanceHelper(n, len, false); +} +static void BM_EditDistanceDiff(int n, int len) { + BM_EditDistanceHelper(n, len, true); +} + +BENCHMARK(BM_EditDistanceSame)->Arg(5); +BENCHMARK(BM_EditDistanceSame)->Arg(50); +BENCHMARK(BM_EditDistanceSame)->Arg(200); +BENCHMARK(BM_EditDistanceSame)->Arg(1000); +BENCHMARK(BM_EditDistanceDiff)->Arg(5); +BENCHMARK(BM_EditDistanceDiff)->Arg(50); +BENCHMARK(BM_EditDistanceDiff)->Arg(200); +BENCHMARK(BM_EditDistanceDiff)->Arg(1000); + +} // namespace +} // namespace gtl +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h new file mode 100644 index 0000000000..c23075129c --- /dev/null +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -0,0 +1,839 @@ +// An InlinedVector is like a std::vector, except that storage +// for sequences of length <= N are provided inline without requiring +// any heap allocation. Typically N is very small (e.g., 4) so that +// sequences that are expected to be short do not require allocations. +// +// Only some of the std::vector<> operations are currently implemented. +// Other operations may be added as needed to facilitate migrating +// code that uses std::vector<> to InlinedVector<>. +// +// NOTE: If you want an inlined version to replace use of a +// std::vector, consider using util::bitmap::InlinedBitVector +// in util/bitmap/inlined_bitvector.h +// +// TODO(billydonahue): change size_t to size_type where appropriate. + +#ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ +#define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/manual_constructor.h" + +#include // NOLINT(build/include_order) + +namespace tensorflow { +namespace gtl { + +template > +class InlinedVector { + public: + typedef A allocator_type; + typedef typename allocator_type::value_type value_type; + typedef typename allocator_type::pointer pointer; + typedef typename allocator_type::const_pointer const_pointer; + typedef typename allocator_type::reference reference; + typedef typename allocator_type::const_reference const_reference; + typedef typename allocator_type::size_type size_type; + typedef typename allocator_type::difference_type difference_type; + typedef pointer iterator; + typedef const_pointer const_iterator; + + // Create an empty vector + InlinedVector(); + explicit InlinedVector(const allocator_type& alloc); + + // Create a vector with n copies of value_type(). + explicit InlinedVector(size_t n); + + // Create a vector with n copies of elem + InlinedVector(size_t n, const value_type& elem, + const allocator_type& alloc = allocator_type()); + + // Create and initialize with the elements [range_start .. range_end). + // The unused enable_if argument restricts this constructor so that it is + // elided when value_type is an integral type. 
This prevents ambiguous + // interpretation between a call to this constructor with two integral + // arguments and a call to the preceding (n, elem) constructor. + template + InlinedVector( + InputIterator range_start, InputIterator range_end, + const allocator_type& alloc = allocator_type(), + typename std::enable_if::value>::type* = + NULL) + : allocator_and_tag_(alloc) { + AppendRange(range_start, range_end); + } + + InlinedVector(std::initializer_list init, + const allocator_type& alloc = allocator_type()) + : allocator_and_tag_(alloc) { + AppendRange(init.begin(), init.end()); + } + + InlinedVector(const InlinedVector& v); + + ~InlinedVector() { clear(); } + + InlinedVector& operator=(const InlinedVector& v) { + // Optimized to avoid reallocation. + // Prefer reassignment to copy construction for elements. + if (size() < v.size()) { // grow + reserve(v.size()); + std::copy(v.begin(), v.begin() + size(), begin()); + std::copy(v.begin() + size(), v.end(), std::back_inserter(*this)); + } else { // maybe shrink + erase(begin() + v.size(), end()); + std::copy(v.begin(), v.end(), begin()); + } + return *this; + } + + size_t size() const { + return allocated() ? allocation().size() : tag().size(); + } + + bool empty() const { return (size() == 0); } + + // Return number of elements that can be stored in vector + // without requiring a reallocation of underlying memory + size_t capacity() const { return allocated() ? allocation().capacity() : N; } + + // Return a pointer to the underlying array. + // Only result[0,size()-1] are defined. + const_pointer data() const { + return allocated() ? allocated_space() : inlined_space(); + } + pointer data() { return allocated() ? allocated_space() : inlined_space(); } + + // An older name for the more standard-friendly .data(). + const_pointer array() const { return data(); } + pointer mutable_array() { return data(); } + + // Remove all elements + void clear() { + size_t s = size(); + if (allocated()) { + DestroyAllocated(allocated_space(), allocated_space() + s); + allocation().Dealloc(allocator()); + } else { + DestroyInlined(inlined_space(), inlined_space() + s); + } + tag() = Tag(); + } + + // Return the ith element + // REQUIRES: 0 <= i < size() + const value_type& at(size_t i) const { + DCHECK_LT(i, size()); + return array()[i]; + } + const value_type& operator[](size_t i) const { + DCHECK_LT(i, size()); + return array()[i]; + } + + // Return a non-const reference to the ith element + // REQUIRES: 0 <= i < size() + value_type& at(size_t i) { + DCHECK_LT(i, size()); + return mutable_array()[i]; + } + value_type& operator[](size_t i) { + DCHECK_LT(i, size()); + return mutable_array()[i]; + } + + value_type& back() { + DCHECK(!empty()); + return at(size() - 1); + } + + const value_type& back() const { + DCHECK(!empty()); + return at(size() - 1); + } + + value_type& front() { + DCHECK(!empty()); + return at(0); + } + + const value_type& front() const { + DCHECK(!empty()); + return at(0); + } + + // Append t to the vector. + // Increases size() by one. 
+ // Amortized complexity: O(1) + // Worst-case complexity: O(size()) + void push_back(const value_type& t) { + size_t s = size(); + DCHECK_LE(s, capacity()); + if (s == capacity()) { + return GrowAndPushBack(t); + } + DCHECK_LT(s, capacity()); + + if (allocated()) { + ConstructAllocated(allocated_space() + s, t); + } else { + ConstructInlined(inlined_space() + s, t); + } + + set_size_internal(s + 1); + } + + void pop_back() { + DCHECK(!empty()); + size_t s = size(); + if (allocated()) { + DestroyAllocated(allocated_space() + s - 1, allocated_space() + s); + } else { + DestroyInlined(inlined_space() + s - 1, inlined_space() + s); + } + set_size_internal(s - 1); + } + + // Resizes the vector to contain "n" elements. + // If "n" is smaller than the initial size, extra elements are destroyed. + // If "n" is larger than the initial size, enough copies of "elem" + // are appended to increase the size to "n". If "elem" is omitted, + // new elements are value-initialized. + void resize(size_t n); + void resize(size_t n, const value_type& elem); + + iterator begin() { return mutable_array(); } + const_iterator begin() const { return array(); } + + iterator end() { return mutable_array() + size(); } + const_iterator end() const { return array() + size(); } + + iterator insert(iterator pos, const value_type& v); + + iterator erase(iterator pos) { + DCHECK_LT(pos, end()); + DCHECK_GE(pos, begin()); + std::copy(pos + 1, end(), pos); + pop_back(); + return pos; + } + + iterator erase(iterator first, iterator last); + + // Enlarges the underlying representation so it can hold at least + // "n" elements without reallocation. + // Does not change size() or the actual contents of the vector. + void reserve(size_t n) { + if (n > capacity()) { + // Make room for new elements + EnlargeBy(n - size()); + } + } + + // Swap the contents of *this with other. + // REQUIRES: value_type is swappable and copyable. + void swap(InlinedVector& other); + + allocator_type get_allocator() const { return allocator(); } + + private: + struct AllocatorTraits { + typedef typename allocator_type::value_type value_type; + typedef typename allocator_type::pointer pointer; + typedef typename allocator_type::size_type size_type; + + static void construct(allocator_type& a, // NOLINT(runtime/references) + pointer p) { + // Tricky: do we support non-copyable types, or support allocators + // that do special things with construct()? Non-copyable types are + // needed today, so they are more important. When we sort out the + // Android NDK C++11 problem, we will be able to use the proper + // std::allocator_traits::construct(p, ...). + // + // a.construct(p, value_type()); + new (p) value_type(); + } + static void construct(allocator_type& a, // NOLINT(runtime/references) + pointer p, const value_type& t) { + a.construct(p, t); + } + static void destroy(allocator_type& a, // NOLINT(runtime/references) + pointer p) { + a.destroy(p); + } + static pointer allocate(allocator_type& a, // NOLINT(runtime/references) + size_type n) { + return a.allocate(n); + } + static void deallocate(allocator_type& a, // NOLINT(runtime/references) + pointer p, size_type n) { + a.deallocate(p, n); + } + }; + + // If the vector is inlined, holds the size of the vector. + // If the vector is allocated, holds the special value kAllocated, + // and the size is stored in the vector's Allocation. 
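+  // Illustrative sketch (not part of the original source): the tag packs the
+  // two states into a single size_t, e.g.
+  //   Tag t;               // inlined, size() == 0
+  //   t.set_size(3);       // inlined, size() == 3
+  //   t.set_allocated();   // size_ becomes kAllocated; the real size now
+  //                        // lives in the Allocation object.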
+ class Tag { + public: + Tag() : size_(0) {} + size_t size() const { return size_; } + void set_size(size_t n) { size_ = n; } + bool allocated() const { return size_ == kAllocated; } + void set_allocated() { size_ = kAllocated; } + + private: + static const size_t kAllocated = -1; + size_t size_; + }; + + // Derives from allocator_type to use the empty base class optimization. + // If the allocator_type is stateless, we can 'store' + // our instance of it for free. + class AllocatorAndTag : private allocator_type { + public: + explicit AllocatorAndTag(const allocator_type& a, Tag t = Tag()) + : allocator_type(a), tag_(t) {} + Tag& tag() { return tag_; } + const Tag& tag() const { return tag_; } + allocator_type& allocator() { return *this; } + const allocator_type& allocator() const { return *this; } + + private: + Tag tag_; + }; + + class Allocation { + public: + Allocation(allocator_type& a, // NOLINT(runtime/references) + size_t capacity) + : size_(0), + capacity_(capacity), + buffer_(AllocatorTraits::allocate(a, capacity_)) {} + + void Dealloc(allocator_type& a) { // NOLINT(runtime/references) + AllocatorTraits::deallocate(a, buffer(), capacity()); + } + + size_t size() const { return size_; } + void set_size(size_t s) { size_ = s; } + size_t capacity() const { return capacity_; } + const value_type* buffer() const { return buffer_; } + value_type* buffer() { return buffer_; } + + private: + size_t size_; + size_t capacity_; + value_type* buffer_; + }; + + const Tag& tag() const { return allocator_and_tag_.tag(); } + Tag& tag() { return allocator_and_tag_.tag(); } + + Allocation& allocation() { return *rep_.allocation_storage.allocation.get(); } + const Allocation& allocation() const { + return *rep_.allocation_storage.allocation.get(); + } + void init_allocation(const Allocation& allocation) { + rep_.allocation_storage.allocation.Init(allocation); + } + + value_type* inlined_space() { return rep_.inlined_storage.inlined[0].get(); } + const value_type* inlined_space() const { + return rep_.inlined_storage.inlined[0].get(); + } + + value_type* allocated_space() { return allocation().buffer(); } + const value_type* allocated_space() const { return allocation().buffer(); } + + const allocator_type& allocator() const { + return allocator_and_tag_.allocator(); + } + allocator_type& allocator() { return allocator_and_tag_.allocator(); } + + bool allocated() const { return tag().allocated(); } + void set_allocated() { return tag().set_allocated(); } + + void set_size_internal(size_t n) { + if (allocated()) { + allocation().set_size(n); + } else { + tag().set_size(n); + } + } + + // Enlarge the underlying representation so we can store size_ + delta elems. + // The size is not changed, and any newly added memory is not initialized. 
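+  // For example (hedged illustration, not original text): an inlined vector
+  // with N == 4, size() == 4 and capacity() == 4 calling EnlargeBy(1) switches
+  // to a heap allocation whose capacity is doubled until it covers
+  // size + delta, here 8; the existing 4 elements are copied over and size()
+  // stays 4.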
+ void EnlargeBy(size_t delta); + + void ResetAllocation(Allocation new_allocation) { + if (allocated()) { + DestroyAllocated(allocated_space(), allocated_space() + size()); + DCHECK_EQ(begin(), allocated_space()); + allocation().Dealloc(allocator()); + allocation() = new_allocation; + } else { + DestroyInlined(inlined_space(), inlined_space() + size()); + init_allocation(new_allocation); // bug: only init once + set_allocated(); + } + } + + void GrowAndPushBack(const value_type& t) { + DCHECK_EQ(size(), capacity()); + const size_t s = size(); + + Allocation new_allocation(allocator(), 2 * capacity()); + new_allocation.set_size(s + 1); + + UninitializedCopyAllocated(array(), array() + s, new_allocation.buffer()); + ConstructAllocated(new_allocation.buffer() + s, t); + + ResetAllocation(new_allocation); + } + + void InitAssign(size_t n); + void InitAssign(size_t n, const value_type& t); + + void ConstructInlined(pointer p) { new (p) value_type(); } + + void ConstructInlined(pointer p, const value_type& t) { + new (p) value_type(t); + } + + void ConstructAllocated(pointer p) { + AllocatorTraits::construct(allocator(), p); + } + void ConstructAllocated(pointer p, const value_type& t) { + AllocatorTraits::construct(allocator(), p, t); + } + + template + void UninitializedCopyInlined(Iter src, Iter src_last, value_type* dst) { + std::uninitialized_copy(src, src_last, dst); + } + + template + void UninitializedCopyAllocated(Iter src, Iter src_last, value_type* dst) { + for (; src != src_last; ++dst, ++src) ConstructAllocated(dst, *src); + } + + void UninitializedFillInlined(value_type* dst, value_type* dst_last) { + for (; dst != dst_last; ++dst) ConstructInlined(dst); + } + void UninitializedFillInlined(value_type* dst, value_type* dst_last, + const value_type& t) { + std::uninitialized_fill(dst, dst_last, t); + } + + void UninitializedFillAllocated(value_type* dst, value_type* dst_last) { + for (; dst != dst_last; ++dst) ConstructAllocated(dst); + } + void UninitializedFillAllocated(value_type* dst, value_type* dst_last, + const value_type& t) { + for (; dst != dst_last; ++dst) ConstructAllocated(dst, t); + } + + // Destroy [ptr, ptr_last) in place. + void DestroyInlined(value_type* ptr, value_type* ptr_last); + void DestroyAllocated(value_type* ptr, value_type* ptr_last); + + template + void AppendRange(Iter first, Iter last, std::input_iterator_tag); + + // Faster path for forward iterators. + template + void AppendRange(Iter first, Iter last, std::forward_iterator_tag); + + template + void AppendRange(Iter first, Iter last); + + AllocatorAndTag allocator_and_tag_; + + // Either the inlined or allocated representation + union Rep { + // Use struct to perform indirection that solves a bizarre compilation + // error on Visual Studio (all known versions). 
+ struct { + tensorflow::ManualConstructor inlined[N]; + } inlined_storage; + struct { + tensorflow::ManualConstructor allocation; + } allocation_storage; + } rep_; +}; + +template +const size_t InlinedVector::Tag::kAllocated; + +template +inline void swap(InlinedVector& a, InlinedVector& b) { + a.swap(b); +} + +template +inline bool operator==(const InlinedVector& a, + const InlinedVector& b) { + return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); +} + +template +inline bool operator!=(const InlinedVector& a, + const InlinedVector& b) { + return !(a == b); +} + +template +inline bool operator<(const InlinedVector& a, + const InlinedVector& b) { + return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); +} + +template +inline bool operator>(const InlinedVector& a, + const InlinedVector& b) { + return b < a; +} + +template +inline bool operator<=(const InlinedVector& a, + const InlinedVector& b) { + return !(b < a); +} + +template +inline bool operator>=(const InlinedVector& a, + const InlinedVector& b) { + return !(a < b); +} + +// ======================================== +// Implementation + +template +inline InlinedVector::InlinedVector() + : allocator_and_tag_(allocator_type()) {} + +template +inline InlinedVector::InlinedVector(const allocator_type& alloc) + : allocator_and_tag_(alloc) {} + +template +inline InlinedVector::InlinedVector(size_t n) + : allocator_and_tag_(allocator_type()) { + InitAssign(n); +} + +template +inline InlinedVector::InlinedVector(size_t n, const value_type& elem, + const allocator_type& alloc) + : allocator_and_tag_(alloc) { + InitAssign(n, elem); +} + +template +inline InlinedVector::InlinedVector(const InlinedVector& v) + : allocator_and_tag_(v.allocator()) { + reserve(v.size()); + if (allocated()) { + UninitializedCopyAllocated(v.begin(), v.end(), allocated_space()); + } else { + UninitializedCopyInlined(v.begin(), v.end(), inlined_space()); + } + set_size_internal(v.size()); +} + +template +inline void InlinedVector::InitAssign(size_t n, const value_type& t) { + if (n > static_cast(N)) { + Allocation new_allocation(allocator(), n); + init_allocation(new_allocation); + set_allocated(); + UninitializedFillAllocated(allocated_space(), allocated_space() + n, t); + } else { + UninitializedFillInlined(inlined_space(), inlined_space() + n, t); + } + set_size_internal(n); +} + +template +inline void InlinedVector::InitAssign(size_t n) { + if (n > static_cast(N)) { + Allocation new_allocation(allocator(), n); + init_allocation(new_allocation); + set_allocated(); + UninitializedFillAllocated(allocated_space(), allocated_space() + n); + } else { + UninitializedFillInlined(inlined_space(), inlined_space() + n); + } + set_size_internal(n); +} + +template +inline void InlinedVector::resize(size_t n) { + size_t s = size(); + if (n < s) { + erase(begin() + n, end()); + return; + } + reserve(n); + DCHECK_GE(capacity(), n); + + // Fill new space with elements constructed in-place. + if (allocated()) { + UninitializedFillAllocated(allocated_space() + s, allocated_space() + n); + } else { + UninitializedFillInlined(inlined_space() + s, inlined_space() + n); + } + set_size_internal(n); +} + +template +inline void InlinedVector::resize(size_t n, const value_type& elem) { + size_t s = size(); + if (n < s) { + erase(begin() + n, end()); + return; + } + reserve(n); + DCHECK_GE(capacity(), n); + + // Fill new space with copies of 'elem'. 
+ if (allocated()) { + UninitializedFillAllocated(allocated_space() + s, allocated_space() + n, + elem); + } else { + UninitializedFillInlined(inlined_space() + s, inlined_space() + n, elem); + } + set_size_internal(n); +} + +template +typename InlinedVector::iterator InlinedVector::insert( + iterator pos, const value_type& v) { + DCHECK_GE(pos, begin()); + DCHECK_LE(pos, end()); + if (pos == end()) { + push_back(v); + return end() - 1; + } + size_t s = size(); + size_t idx = std::distance(begin(), pos); + if (s == capacity()) { + EnlargeBy(1); + } + CHECK_LT(s, capacity()); + pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator. + + if (allocated()) { + ConstructAllocated(allocated_space() + s, *(allocated_space() + s - 1)); + std::copy_backward(pos, allocated_space() + s - 1, allocated_space() + s); + } else { + ConstructInlined(inlined_space() + s, *(inlined_space() + s - 1)); + std::copy_backward(pos, inlined_space() + s - 1, inlined_space() + s); + } + + *pos = v; + + set_size_internal(s + 1); + return pos; +} + +template +typename InlinedVector::iterator InlinedVector::erase( + iterator first, iterator last) { + DCHECK_LE(begin(), first); + DCHECK_LE(first, last); + DCHECK_LE(last, end()); + + size_t s = size(); + ptrdiff_t erase_gap = std::distance(first, last); + + if (allocated()) { + std::copy(last, allocated_space() + s, first); + DestroyAllocated(allocated_space() + s - erase_gap, allocated_space() + s); + } else { + std::copy(last, inlined_space() + s, first); + DestroyInlined(inlined_space() + s - erase_gap, inlined_space() + s); + } + + set_size_internal(size() - erase_gap); + + return first; +} + +template +void InlinedVector::swap(InlinedVector& other) { + using std::swap; // Augment ADL with std::swap. + if (&other == this) { + return; + } + if (allocated() && other.allocated()) { + // Both out of line, so just swap the tag, allocation, and allocator. + swap(tag(), other.tag()); + swap(allocation(), other.allocation()); + swap(allocator(), other.allocator()); + return; + } + if (!allocated() && !other.allocated()) { + // Both inlined: swap up to smaller size, then move remaining elements. + InlinedVector* a = this; + InlinedVector* b = &other; + if (size() < other.size()) { + swap(a, b); + } + + const size_t a_size = a->size(); + const size_t b_size = b->size(); + DCHECK_GE(a_size, b_size); + // 'a' is larger. Swap the elements up to the smaller array size. + std::swap_ranges(a->inlined_space(), a->inlined_space() + b_size, + b->inlined_space()); + + // Move the remaining elements: A[b_size,a_size) -> B[b_size,a_size) + b->UninitializedCopyInlined(a->inlined_space() + b_size, + a->inlined_space() + a_size, + b->inlined_space() + b_size); + a->DestroyInlined(a->inlined_space() + b_size, a->inlined_space() + a_size); + + swap(a->tag(), b->tag()); + swap(a->allocator(), b->allocator()); + DCHECK_EQ(b->size(), a_size); + DCHECK_EQ(a->size(), b_size); + return; + } + // One is out of line, one is inline. + // We first move the elements from the inlined vector into the + // inlined space in the other vector. We then put the other vector's + // pointer/capacity into the originally inlined vector and swap + // the tags. 
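+  // Illustrative example (comment only, not in the original): if *this is
+  // inlined holding {1, 2} and 'other' is heap-allocated holding
+  // {7, 8, 9, 10, 11}, then after the swap *this owns the heap buffer with
+  // {7..11} and 'other' holds {1, 2} in its inline space; only the two small
+  // elements are copied, the large buffer merely changes owners.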
+ InlinedVector* a = this; + InlinedVector* b = &other; + if (a->allocated()) { + swap(a, b); + } + DCHECK(!a->allocated()); + DCHECK(b->allocated()); + const size_t a_size = a->size(); + const size_t b_size = b->size(); + + // Made Local copies of size(), don't need tag() accurate anymore + swap(a->tag(), b->tag()); + + // Copy b_allocation out before b's union gets clobbered by inline_space. + Allocation b_allocation = b->allocation(); + + b->UninitializedCopyInlined(a->inlined_space(), a->inlined_space() + a_size, + b->inlined_space()); + a->DestroyInlined(a->inlined_space(), a->inlined_space() + a_size); + + a->allocation() = b_allocation; + + if (a->allocator() != b->allocator()) { + swap(a->allocator(), b->allocator()); + } + + DCHECK_EQ(b->size(), a_size); + DCHECK_EQ(a->size(), b_size); +} + +template +void InlinedVector::EnlargeBy(size_t delta) { + const size_t s = size(); + DCHECK_LE(s, capacity()); + + size_t target = std::max(static_cast(N), s + delta); + + // Compute new capacity by repeatedly doubling current capacity + // TODO(psrc): Check and avoid overflow? + size_t new_capacity = capacity(); + while (new_capacity < target) { + new_capacity <<= 1; + } + + Allocation new_allocation(allocator(), new_capacity); + new_allocation.set_size(s); + + UninitializedCopyAllocated(array(), array() + s, new_allocation.buffer()); + + ResetAllocation(new_allocation); +} + +template +inline void InlinedVector::DestroyInlined(value_type* ptr, + value_type* ptr_last) { + for (value_type* p = ptr; p != ptr_last; ++p) { + p->~value_type(); + } + +// Overwrite unused memory with 0xab so we can catch uninitialized usage. +// Cast to void* to tell the compiler that we don't care that we might be +// scribbling on a vtable pointer. +#ifndef NDEBUG + if (ptr != ptr_last) { + memset(reinterpret_cast(ptr), 0xab, sizeof(*ptr) * (ptr_last - ptr)); + } +#endif +} + +template +inline void InlinedVector::DestroyAllocated(value_type* ptr, + value_type* ptr_last) { + for (value_type* p = ptr; p != ptr_last; ++p) { + AllocatorTraits::destroy(allocator(), p); + } + +// Overwrite unused memory with 0xab so we can catch uninitialized usage. +// Cast to void* to tell the compiler that we don't care that we might be +// scribbling on a vtable pointer. 
+#ifndef NDEBUG + if (ptr != ptr_last) { + memset(reinterpret_cast(ptr), 0xab, sizeof(*ptr) * (ptr_last - ptr)); + } +#endif +} + +template +template +inline void InlinedVector::AppendRange(Iter first, Iter last, + std::input_iterator_tag) { + std::copy(first, last, std::back_inserter(*this)); +} + +template +template +inline void InlinedVector::AppendRange(Iter first, Iter last, + std::forward_iterator_tag) { + typedef typename std::iterator_traits::difference_type Length; + Length length = std::distance(first, last); + reserve(size() + length); + if (allocated()) { + UninitializedCopyAllocated(first, last, allocated_space() + size()); + } else { + UninitializedCopyInlined(first, last, inlined_space() + size()); + } + set_size_internal(size() + length); +} + +template +template +inline void InlinedVector::AppendRange(Iter first, Iter last) { + typedef typename std::iterator_traits::iterator_category IterTag; + AppendRange(first, last, IterTag()); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc new file mode 100644 index 0000000000..ec5fe1eaa8 --- /dev/null +++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc @@ -0,0 +1,905 @@ +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +#include +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { + +typedef tensorflow::gtl::InlinedVector IntVec; + +// A type that counts number of live occurrences of the type +static int64 instances = 0; +class Instance { + public: + int value_; + explicit Instance(int x) : value_(x) { instances++; } + Instance(const Instance& x) : value_(x.value_) { instances++; } + ~Instance() { instances--; } + + friend inline void swap(Instance& a, Instance& b) { + using std::swap; + swap(a.value_, b.value_); + } + + friend std::ostream& operator<<(std::ostream& o, const Instance& v) { + return o << "[value:" << v.value_ << "]"; + } +}; + +typedef tensorflow::gtl::InlinedVector InstanceVec; + +// A simple reference counted class to make sure that the proper elements are +// destroyed in the erase(begin, end) test. 
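+// Illustrative note (not part of the original test): every live copy of a
+// RefCounted value holds a +1 on the shared counter, e.g.
+//   int count = 0;
+//   RefCountedVec v;
+//   v.push_back(RefCounted(0, &count));  // temporary: 1 -> copy in vector: 2
+//                                        // -> temporary destroyed: 1
+// so at any point *count equals the number of still-live objects that
+// reference it.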
+class RefCounted { + public: + RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); } + + RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) { + VLOG(5) << "[RefCounted: copy" + << " from count @" << v.count_ << "]"; + Ref(); + } + + ~RefCounted() { + Unref(); + count_ = NULL; + } + + friend void swap(RefCounted& a, RefCounted& b) { + using std::swap; + swap(a.value_, b.value_); + swap(a.count_, b.count_); + } + + RefCounted& operator=(RefCounted v) { + using std::swap; + swap(*this, v); + return *this; + } + + void Ref() const { + CHECK(count_ != NULL); + ++(*count_); + VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]"; + } + + void Unref() const { + --(*count_); + CHECK_GE(*count_, 0); + VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]"; + } + + int count() const { return *count_; } + + friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) { + return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]"; + } + + int value_; + int* count_; +}; + +typedef tensorflow::gtl::InlinedVector RefCountedVec; + +// A class with a vtable pointer +class Dynamic { + public: + virtual ~Dynamic() {} + + friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) { + return o << "[Dynamic]"; + } +}; + +typedef tensorflow::gtl::InlinedVector DynamicVec; + +// Append 0..len-1 to *v +static void Fill(IntVec* v, int len, int offset = 0) { + for (int i = 0; i < len; i++) { + v->push_back(i + offset); + } +} + +static IntVec Fill(int len, int offset = 0) { + IntVec v; + Fill(&v, len, offset); + return v; +} + +TEST(IntVec, SimpleOps) { + for (int len = 0; len < 20; len++) { + IntVec v; + const IntVec& cv = v; // const alias + + Fill(&v, len); + EXPECT_EQ(len, v.size()); + EXPECT_LE(len, v.capacity()); + + for (int i = 0; i < len; i++) { + EXPECT_EQ(i, v[i]); + } + EXPECT_EQ(v.begin(), v.array()); + EXPECT_EQ(v.begin(), v.mutable_array()); + + EXPECT_EQ(v.begin(), v.data()); + EXPECT_EQ(cv.begin(), cv.data()); + + int counter = 0; + for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) { + EXPECT_EQ(counter, *iter); + counter++; + } + EXPECT_EQ(counter, len); + + counter = 0; + for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) { + EXPECT_EQ(counter, *iter); + counter++; + } + EXPECT_EQ(counter, len); + + if (len > 0) { + EXPECT_EQ(0, v.front()); + EXPECT_EQ(len - 1, v.back()); + v.pop_back(); + EXPECT_EQ(len - 1, v.size()); + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(i, v[i]); + } + } + } +} + +TEST(IntVec, Erase) { + for (int len = 1; len < 20; len++) { + for (int i = 0; i < len; ++i) { + IntVec v; + Fill(&v, len); + v.erase(v.begin() + i); + EXPECT_EQ(len - 1, v.size()); + for (int j = 0; j < i; ++j) { + EXPECT_EQ(j, v[j]); + } + for (int j = i; j < len - 1; ++j) { + EXPECT_EQ(j + 1, v[j]); + } + } + } +} + +// At the end of this test loop, the elements between [erase_begin, erase_end) +// should have reference counts == 0, and all others elements should have +// reference counts == 1. 
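+// Worked example (hedged, not in the original comment): with len == 4 and the
+// erase range [1, 3), elements 1 and 2 are destroyed, so the counts end up as
+// {1, 0, 0, 1} and the surviving vector holds the values {0, 3}.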
+TEST(RefCountedVec, EraseBeginEnd) { + for (int len = 1; len < 20; ++len) { + for (int erase_begin = 0; erase_begin < len; ++erase_begin) { + for (int erase_end = erase_begin; erase_end <= len; ++erase_end) { + std::vector counts(len, 0); + RefCountedVec v; + for (int i = 0; i < len; ++i) { + v.push_back(RefCounted(i, &counts[i])); + } + + int erase_len = erase_end - erase_begin; + + v.erase(v.begin() + erase_begin, v.begin() + erase_end); + + EXPECT_EQ(len - erase_len, v.size()); + + // Check the elements before the first element erased. + for (int i = 0; i < erase_begin; ++i) { + EXPECT_EQ(i, v[i].value_); + } + + // Check the elements after the first element erased. + for (size_t i = erase_begin; i < v.size(); ++i) { + EXPECT_EQ(i + erase_len, v[i].value_); + } + + // Check that the elements at the beginning are preserved. + for (int i = 0; i < erase_begin; ++i) { + EXPECT_EQ(1, counts[i]); + } + + // Check that the erased elements are destroyed + for (int i = erase_begin; i < erase_end; ++i) { + EXPECT_EQ(0, counts[i]); + } + + // Check that the elements at the end are preserved. + for (int i = erase_end; i < len; ++i) { + EXPECT_EQ(1, counts[i]); + } + } + } + } +} + +struct NoDefaultCtor { + explicit NoDefaultCtor(int /* x */) {} +}; +struct NoCopy { + NoCopy() {} + NoCopy(const NoCopy& /* x */) = delete; +}; +struct NoAssign { + NoAssign() {} + NoAssign& operator=(const NoAssign& /* x */) = delete; +}; +TEST(InlinedVectorTest, NoDefaultCtor) { + tensorflow::gtl::InlinedVector v(10, NoDefaultCtor(2)); + (void)v; +} +TEST(InlinedVectorTest, NoCopy) { + tensorflow::gtl::InlinedVector v(10); + (void)v; +} +TEST(InlinedVectorTest, NoAssign) { + tensorflow::gtl::InlinedVector v(10); + (void)v; +} + +TEST(IntVec, Insert) { + for (int len = 0; len < 20; len++) { + for (int pos = 0; pos <= len; pos++) { + IntVec v; + Fill(&v, len); + v.insert(v.begin() + pos, 9999); + EXPECT_EQ(v.size(), len + 1); + for (int i = 0; i < pos; i++) { + EXPECT_EQ(v[i], i); + } + EXPECT_EQ(v[pos], 9999); + for (size_t i = pos + 1; i < v.size(); i++) { + EXPECT_EQ(v[i], i - 1); + } + } + } +} + +TEST(RefCountedVec, InsertConstructorDestructor) { + // Make sure the proper construction/destruction happen during insert + // operations. + for (int len = 0; len < 20; len++) { + SCOPED_TRACE(len); + for (int pos = 0; pos <= len; pos++) { + SCOPED_TRACE(pos); + std::vector counts(len, 0); + RefCountedVec v; + for (int i = 0; i < len; ++i) { + SCOPED_TRACE(i); + v.push_back(RefCounted(i, &counts[i])); + } + + for (auto elem : counts) { + EXPECT_EQ(1, elem); + } + + int inserted_count = 0; + RefCounted insert_element(9999, &inserted_count); + EXPECT_EQ(1, inserted_count); + v.insert(v.begin() + pos, insert_element); + EXPECT_EQ(2, inserted_count); + // Check that the elements at the end are preserved. 
+ for (auto elem : counts) { + EXPECT_EQ(1, elem); + } + EXPECT_EQ(2, inserted_count); + } + } +} + +TEST(IntVec, Resize) { + for (int len = 0; len < 20; len++) { + IntVec v; + Fill(&v, len); + + // Try resizing up and down by k elements + static const int kResizeElem = 1000000; + for (int k = 0; k < 10; k++) { + // Enlarging resize + v.resize(len + k, kResizeElem); + EXPECT_EQ(len + k, v.size()); + EXPECT_LE(len + k, v.capacity()); + for (int i = 0; i < len + k; i++) { + if (i < len) { + EXPECT_EQ(i, v[i]); + } else { + EXPECT_EQ(kResizeElem, v[i]); + } + } + + // Shrinking resize + v.resize(len, kResizeElem); + EXPECT_EQ(len, v.size()); + EXPECT_LE(len, v.capacity()); + for (int i = 0; i < len; i++) { + EXPECT_EQ(i, v[i]); + } + } + } +} + +TEST(IntVec, InitWithLength) { + for (int len = 0; len < 20; len++) { + IntVec v(len, 7); + EXPECT_EQ(len, v.size()); + EXPECT_LE(len, v.capacity()); + for (int i = 0; i < len; i++) { + EXPECT_EQ(7, v[i]); + } + } +} + +TEST(IntVec, CopyConstructorAndAssignment) { + for (int len = 0; len < 20; len++) { + IntVec v; + Fill(&v, len); + EXPECT_EQ(len, v.size()); + EXPECT_LE(len, v.capacity()); + + IntVec v2(v); + EXPECT_EQ(v, v2); + + for (int start_len = 0; start_len < 20; start_len++) { + IntVec v3; + Fill(&v3, start_len, 99); // Add dummy elements that should go away + v3 = v; + EXPECT_EQ(v, v3); + } + } +} + +TEST(OverheadTest, Storage) { + // Check for size overhead. + // In particular, ensure that std::allocator doesn't cost anything to store. + // The union should be absorbing some of the allocation bookkeeping overhead + // in the larger vectors, leaving only the size_ field as overhead. + using tensorflow::gtl::InlinedVector; + EXPECT_EQ(3 * sizeof(int*), + sizeof(InlinedVector) - 1 * sizeof(int*)); + EXPECT_EQ(2 * sizeof(int*), + sizeof(InlinedVector) - 2 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector) - 3 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector) - 4 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector) - 5 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector) - 6 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector) - 7 * sizeof(int*)); + EXPECT_EQ(1 * sizeof(int*), + sizeof(InlinedVector) - 8 * sizeof(int*)); +} + +TEST(IntVec, Clear) { + for (int len = 0; len < 20; len++) { + SCOPED_TRACE(len); + IntVec v; + Fill(&v, len); + v.clear(); + EXPECT_EQ(0, v.size()); + EXPECT_EQ(v.begin(), v.end()); + } +} + +TEST(IntVec, Reserve) { + for (size_t len = 0; len < 20; len++) { + IntVec v; + Fill(&v, len); + + for (size_t newlen = 0; newlen < 100; newlen++) { + const int* start_rep = v.array(); + v.reserve(newlen); + const int* final_rep = v.array(); + if (newlen <= len) { + EXPECT_EQ(start_rep, final_rep); + } + EXPECT_LE(newlen, v.capacity()); + + // Filling up to newlen should not change rep + while (v.size() < newlen) { + v.push_back(0); + } + EXPECT_EQ(final_rep, v.array()); + } + } +} + +template +static std::vector Vec(const T& src) { + std::vector result; + for (const auto& elem : src) { + result.push_back(elem); + } + return result; +} + +TEST(IntVec, SelfRefPushBack) { + std::vector std_v; + tensorflow::gtl::InlinedVector v; + const string s = "A very long string to ensure heap."; + std_v.push_back(s); + v.push_back(s); + for (int i = 0; i < 20; ++i) { + EXPECT_EQ(std_v, Vec(v)); + + v.push_back(v.back()); + std_v.push_back(std_v.back()); + } + EXPECT_EQ(std_v, Vec(v)); +} + +TEST(IntVec, Swap) { + for (int l1 = 0; l1 < 20; l1++) { + 
SCOPED_TRACE(l1); + for (int l2 = 0; l2 < 20; l2++) { + SCOPED_TRACE(l2); + IntVec a = Fill(l1, 0); + IntVec b = Fill(l2, 100); + { + using std::swap; + swap(a, b); + } + EXPECT_EQ(l1, b.size()); + EXPECT_EQ(l2, a.size()); + for (int i = 0; i < l1; i++) { + SCOPED_TRACE(i); + EXPECT_EQ(i, b[i]); + } + for (int i = 0; i < l2; i++) { + SCOPED_TRACE(i); + EXPECT_EQ(100 + i, a[i]); + } + } + } +} + +TEST(InstanceVec, Swap) { + for (int l1 = 0; l1 < 20; l1++) { + for (int l2 = 0; l2 < 20; l2++) { + InstanceVec a, b; + for (int i = 0; i < l1; i++) a.push_back(Instance(i)); + for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i)); + EXPECT_EQ(l1 + l2, instances); + { + using std::swap; + swap(a, b); + } + EXPECT_EQ(l1 + l2, instances); + EXPECT_EQ(l1, b.size()); + EXPECT_EQ(l2, a.size()); + for (int i = 0; i < l1; i++) { + EXPECT_EQ(i, b[i].value_); + } + for (int i = 0; i < l2; i++) { + EXPECT_EQ(100 + i, a[i].value_); + } + } + } +} + +TEST(IntVec, EqualAndNotEqual) { + IntVec a, b; + EXPECT_TRUE(a == b); + EXPECT_FALSE(a != b); + + a.push_back(3); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + + b.push_back(3); + EXPECT_TRUE(a == b); + EXPECT_FALSE(a != b); + + b.push_back(7); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + + a.push_back(6); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + + a.clear(); + b.clear(); + for (int i = 0; i < 100; i++) { + a.push_back(i); + b.push_back(i); + EXPECT_TRUE(a == b); + EXPECT_FALSE(a != b); + + b[i] = b[i] + 1; + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + + b[i] = b[i] - 1; // Back to before + EXPECT_TRUE(a == b); + EXPECT_FALSE(a != b); + } +} + +TEST(IntVec, RelationalOps) { + IntVec a, b; + EXPECT_FALSE(a < b); + EXPECT_FALSE(b < a); + EXPECT_FALSE(a > b); + EXPECT_FALSE(b > a); + EXPECT_TRUE(a <= b); + EXPECT_TRUE(b <= a); + EXPECT_TRUE(a >= b); + EXPECT_TRUE(b >= a); + b.push_back(3); + EXPECT_TRUE(a < b); + EXPECT_FALSE(b < a); + EXPECT_FALSE(a > b); + EXPECT_TRUE(b > a); + EXPECT_TRUE(a <= b); + EXPECT_FALSE(b <= a); + EXPECT_FALSE(a >= b); + EXPECT_TRUE(b >= a); +} + +TEST(InstanceVec, CountConstructorsDestructors) { + const int start = instances; + for (int len = 0; len < 20; len++) { + InstanceVec v; + for (int i = 0; i < len; i++) { + v.push_back(Instance(i)); + } + EXPECT_EQ(start + len, instances); + + { // Copy constructor should create 'len' more instances. 
+ InstanceVec v_copy(v); + EXPECT_EQ(start + len + len, instances); + } + EXPECT_EQ(start + len, instances); + + // Enlarging resize() must construct some objects + v.resize(len + 10, Instance(100)); + EXPECT_EQ(start + len + 10, instances); + + // Shrinking resize() must destroy some objects + v.resize(len, Instance(100)); + EXPECT_EQ(start + len, instances); + + // reserve() must not increase the number of initialized objects + v.reserve(len + 1000); + EXPECT_EQ(start + len, instances); + + // pop_back() and erase() must destroy one object + if (len > 0) { + v.pop_back(); + EXPECT_EQ(start + len - 1, instances); + if (!v.empty()) { + v.erase(v.begin()); + EXPECT_EQ(start + len - 2, instances); + } + } + } + EXPECT_EQ(start, instances); +} + +TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) { + const int start = instances; + for (int len = 0; len < 20; len++) { + for (int longorshort = 0; longorshort <= 1; ++longorshort) { + InstanceVec longer, shorter; + for (int i = 0; i < len; i++) { + longer.push_back(Instance(i)); + shorter.push_back(Instance(i)); + } + longer.push_back(Instance(len)); + EXPECT_EQ(start + len + len + 1, instances); + + if (longorshort) { + shorter = longer; + EXPECT_EQ(start + (len + 1) + (len + 1), instances); + } else { + longer = shorter; + EXPECT_EQ(start + len + len, instances); + } + } + } + EXPECT_EQ(start, instances); +} + +TEST(RangedConstructor, SimpleType) { + std::vector source_v = {4, 5, 6}; + // First try to fit in inline backing + tensorflow::gtl::InlinedVector v(source_v.begin(), source_v.end()); + EXPECT_EQ(3, v.size()); + EXPECT_EQ(4, v.capacity()); // Indication that we're still on inlined storage + EXPECT_EQ(4, v[0]); + EXPECT_EQ(5, v[1]); + EXPECT_EQ(6, v[2]); + + // Now, force a re-allocate + tensorflow::gtl::InlinedVector realloc_v(source_v.begin(), + source_v.end()); + EXPECT_EQ(3, realloc_v.size()); + EXPECT_LT(2, realloc_v.capacity()); + EXPECT_EQ(4, realloc_v[0]); + EXPECT_EQ(5, realloc_v[1]); + EXPECT_EQ(6, realloc_v[2]); +} + +TEST(RangedConstructor, ComplexType) { + // We also use a list here to pass a different flavor of iterator (e.g. not + // random-access). + std::list source_v = {Instance(0)}; + + // First try to fit in inline backing + tensorflow::gtl::InlinedVector v(source_v.begin(), + source_v.end()); + EXPECT_EQ(1, v.size()); + EXPECT_EQ(1, v.capacity()); // Indication that we're still on inlined storage + EXPECT_EQ(0, v[0].value_); + + std::list source_v2 = {Instance(0), Instance(1)}; + // Now, force a re-allocate + tensorflow::gtl::InlinedVector realloc_v(source_v2.begin(), + source_v2.end()); + EXPECT_EQ(2, realloc_v.size()); + EXPECT_LT(1, realloc_v.capacity()); + EXPECT_EQ(0, realloc_v[0].value_); + EXPECT_EQ(1, realloc_v[1].value_); +} + +TEST(RangedConstructor, ElementsAreConstructed) { + std::vector source_v = {"cat", "dog"}; + + // Force expansion and re-allocation of v. Ensures that when the vector is + // expanded that new elements are constructed. 
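+  // (Illustrative note, not original text: with an inline capacity smaller
+  // than source_v.size(), the constructor must heap-allocate and
+  // copy-construct each string into uninitialized storage; assigning into raw
+  // memory instead would be undefined behavior, which is what this test
+  // guards against.)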
+ tensorflow::gtl::InlinedVector v(source_v.begin(), source_v.end()); + EXPECT_EQ("cat", v[0]); + EXPECT_EQ("dog", v[1]); +} + +TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) { + auto vec = tensorflow::gtl::InlinedVector{4, 5, 6}; + EXPECT_EQ(3, vec.size()); + EXPECT_EQ(4, vec.capacity()); + EXPECT_EQ(4, vec[0]); + EXPECT_EQ(5, vec[1]); + EXPECT_EQ(6, vec[2]); +} + +TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) { + auto vec = tensorflow::gtl::InlinedVector{4, 5, 6}; + EXPECT_EQ(3, vec.size()); + EXPECT_LE(3, vec.capacity()); + EXPECT_EQ(4, vec[0]); + EXPECT_EQ(5, vec[1]); + EXPECT_EQ(6, vec[2]); +} + +TEST(InitializerListConstructor, DisparateTypesInList) { + EXPECT_EQ((std::vector{-7, 8}), + Vec(tensorflow::gtl::InlinedVector{-7, 8ULL})); + + EXPECT_EQ( + (std::vector{"foo", "bar"}), + Vec(tensorflow::gtl::InlinedVector{"foo", string("bar")})); +} + +TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) { + auto vec = tensorflow::gtl::InlinedVector{Instance(0)}; + EXPECT_EQ(1, vec.size()); + EXPECT_EQ(1, vec.capacity()); + EXPECT_EQ(0, vec[0].value_); +} + +TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) { + auto vec = + tensorflow::gtl::InlinedVector{Instance(0), Instance(1)}; + EXPECT_EQ(2, vec.size()); + EXPECT_LE(2, vec.capacity()); + EXPECT_EQ(0, vec[0].value_); + EXPECT_EQ(1, vec[1].value_); +} + +TEST(DynamicVec, DynamicVecCompiles) { + DynamicVec v; + (void)v; +} + +#ifdef INLINED_VECTOR_HAS_ALLOC +TEST(AllocatorSupportTest, Constructors) { + typedef STLCountingAllocator MyAlloc; + typedef tensorflow::gtl::InlinedVector AllocVec; + const int ia[] = {0, 1, 2, 3, 4, 5, 6, 7}; + int64 allocated = 0; + MyAlloc alloc(&allocated); + { AllocVec TF_ATTRIBUTE_UNUSED v; } + { AllocVec TF_ATTRIBUTE_UNUSED v(alloc); } + { AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + arraysize(ia), alloc); } +#ifdef LANG_CXX11 + { AllocVec TF_ATTRIBUTE_UNUSED v({1, 2, 3}, alloc); } +#endif // LANG_CXX11 +} + +TEST(AllocatorSupportTest, CountAllocations) { + typedef STLCountingAllocator MyAlloc; + typedef tensorflow::gtl::InlinedVector AllocVec; + const int ia[] = {0, 1, 2, 3, 4, 5, 6, 7}; + int64 allocated = 0; + MyAlloc alloc(&allocated); + { + AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + 4, alloc); + EXPECT_THAT(allocated, 0); + } + EXPECT_THAT(allocated, 0); + { + AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + arraysize(ia), alloc); + EXPECT_THAT(allocated, v.size() * sizeof(int)); + } + EXPECT_THAT(allocated, 0); +} + +TEST(AllocatorSupportTest, SwapBothAllocated) { + typedef STLCountingAllocator MyAlloc; + typedef tensorflow::gtl::InlinedVector AllocVec; + int64 allocated1 = 0; + int64 allocated2 = 0; + { + const std::vector ia1 = {0, 1, 2, 3, 4, 5, 6, 7}; + const std::vector ia2 = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + MyAlloc a1(&allocated1); + MyAlloc a2(&allocated2); + AllocVec v1(ia1.data(), ia1.data() + ia1.size(), a1); + AllocVec v2(ia2.data(), ia2.data() + ia2.size(), a2); + EXPECT_LT(v1.capacity(), v2.capacity()); + EXPECT_THAT(allocated1, v1.capacity() * sizeof(int)); + EXPECT_THAT(allocated2, v2.capacity() * sizeof(int)); + v1.swap(v2); + EXPECT_EQ(ia2, Vec(v1)); + EXPECT_EQ(ia1, Vec(v2)); + EXPECT_THAT(allocated1, v2.capacity() * sizeof(int)); + EXPECT_THAT(allocated2, v1.capacity() * sizeof(int)); + } + EXPECT_THAT(allocated1, 0); + EXPECT_THAT(allocated2, 0); +} + +TEST(AllocatorSupportTest, SwapOneAllocated) { + typedef STLCountingAllocator MyAlloc; + typedef tensorflow::gtl::InlinedVector AllocVec; + int64 allocated1 = 0; + int64 allocated2 = 0; 
+ { + const std::vector ia1 = {0, 1, 2, 3, 4, 5, 6, 7}; + const std::vector ia2 = {0, 1, 2, 3}; + MyAlloc a1(&allocated1); + MyAlloc a2(&allocated2); + AllocVec v1(ia1.data(), ia1.data() + ia1.size(), a1); + AllocVec v2(ia2.data(), ia2.data() + ia2.size(), a2); + EXPECT_THAT(allocated1, v1.capacity() * sizeof(int)); + EXPECT_THAT(allocated2, 0); + v1.swap(v2); + EXPECT_EQ(ia2, Vec(v1)); + EXPECT_EQ(ia1, Vec(v2)); + EXPECT_THAT(allocated1, v2.capacity() * sizeof(int)); + EXPECT_THAT(allocated2, 0); + EXPECT_TRUE(v2.get_allocator() == a1); + EXPECT_TRUE(v1.get_allocator() == a2); + } + EXPECT_THAT(allocated1, 0); + EXPECT_THAT(allocated2, 0); +} +#endif // INLINED_VECTOR_HAS_ALLOC + +static void BM_InlinedVectorFill(int iters, int len) { + for (int i = 0; i < iters; i++) { + IntVec v; + for (int j = 0; j < len; j++) { + v.push_back(j); + } + } + testing::BytesProcessed((static_cast(iters) * len) * sizeof(int)); +} +BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024); + +static void BM_InlinedVectorFillRange(int iters, int len) { + std::unique_ptr ia(new int[len]); + for (int j = 0; j < len; j++) { + ia[j] = j; + } + for (int i = 0; i < iters; i++) { + IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len); + } + testing::BytesProcessed((static_cast(iters) * len) * sizeof(int)); +} +BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024); + +static void BM_StdVectorFill(int iters, int len) { + for (int i = 0; i < iters; i++) { + std::vector v; + for (int j = 0; j < len; j++) { + v.push_back(j); + } + } + testing::BytesProcessed((static_cast(iters) * len) * sizeof(int)); +} +BENCHMARK(BM_StdVectorFill)->Range(0, 1024); + +namespace { +struct Buffer { // some arbitrary structure for benchmarking. + char* base; + int length; + int capacity; + void* user_data; +}; +} // anonymous namespace + +static void BM_InlinedVectorTenAssignments(int iters, int len) { + typedef tensorflow::gtl::InlinedVector BufferVec; + + BufferVec src; + src.resize(len); + + iters *= 10; + BufferVec dst; + for (int i = 0; i < iters; i++) { + dst = src; + } +} +BENCHMARK(BM_InlinedVectorTenAssignments) + ->Arg(0) + ->Arg(1) + ->Arg(2) + ->Arg(3) + ->Arg(4) + ->Arg(20); + +static void BM_CreateFromInitializerList(int iters) { + for (; iters > 0; iters--) { + tensorflow::gtl::InlinedVector x{1, 2, 3}; + (void)x[0]; + } +} +BENCHMARK(BM_CreateFromInitializerList); + +namespace { + +struct LargeSwappable { + LargeSwappable() : d_(1024, 17) {} + ~LargeSwappable() {} + LargeSwappable(const LargeSwappable& o) : d_(o.d_) {} + + friend void swap(LargeSwappable& a, LargeSwappable& b) { + using std::swap; + swap(a.d_, b.d_); + } + + LargeSwappable& operator=(LargeSwappable o) { + using std::swap; + swap(*this, o); + return *this; + } + + std::vector d_; +}; + +} // namespace + +static void BM_LargeSwappableElements(int iters, int len) { + typedef tensorflow::gtl::InlinedVector Vec; + Vec a(len); + Vec b; + while (--iters >= 0) { + using std::swap; + swap(a, b); + } +} +BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024); + +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/int_type.h b/tensorflow/core/lib/gtl/int_type.h new file mode 100644 index 0000000000..d3fcb08d38 --- /dev/null +++ b/tensorflow/core/lib/gtl/int_type.h @@ -0,0 +1,343 @@ +// #status: LEGACY +// #category: Miscellaneous +// #summary: Integral types; prefer util/intops/strong_int.h +// #bugs: Infrastructure > C++ Library Team > util +// +// IntType is a simple template class mechanism for defining "logical" +// integer-like class types that support many of the 
same functionalities +// as native integer types, but which prevent assignment, construction, and +// other operations from other similar integer-like types. Essentially, the +// template class IntType (where ValueType assumes +// valid scalar types such as int, uint, int32, etc) has the additional +// property that it cannot be assigned to or constructed from other IntTypes +// or native integer types of equal or implicitly convertible type. +// +// The class is useful for preventing mingling of integer variables with +// different logical roles or units. Unfortunately, C++ provides relatively +// good type-safety for user-defined classes but not for integer types. It is +// essentially up to the user to use nice variable names and comments to prevent +// accidental mismatches, such as confusing a user-index with a group-index or a +// time-in-milliseconds with a time-in-seconds. The use of typedefs are limited +// in that regard as they do not enforce type-safety. +// +// USAGE ----------------------------------------------------------------------- +// +// DEFINE_INT_TYPE(IntTypeName, ValueType); +// +// where: +// IntTypeName: is the desired (unique) name for the "logical" integer type. +// ValueType: is one of the integral types as defined by base::is_integral +// (see base/type_traits.h). +// +// DISALLOWED OPERATIONS / TYPE-SAFETY ENFORCEMENT ----------------------------- +// +// Consider these definitions and variable declarations: +// DEFINE_INT_TYPE(GlobalDocID, int64); +// DEFINE_INT_TYPE(LocalDocID, int64); +// GlobalDocID global; +// LocalDocID local; +// +// The class IntType prevents: +// +// 1) Assignments of other IntTypes with different IntTypeNames. +// +// global = local; <-- Fails to compile! +// local = global; <-- Fails to compile! +// +// 2) Explicit/implicit conversion from an IntType to another IntType. +// +// LocalDocID l(global); <-- Fails to compile! +// LocalDocID l = global; <-- Fails to compile! +// +// void GetGlobalDoc(GlobalDocID global) { } +// GetGlobalDoc(global); <-- Compiles fine, types match! +// GetGlobalDoc(local); <-- Fails to compile! +// +// 3) Implicit conversion from an IntType to a native integer type. +// +// void GetGlobalDoc(int64 global) { ... +// GetGlobalDoc(global); <-- Fails to compile! +// GetGlobalDoc(local); <-- Fails to compile! +// +// void GetLocalDoc(int32 local) { ... +// GetLocalDoc(global); <-- Fails to compile! +// GetLocalDoc(local); <-- Fails to compile! +// +// +// SUPPORTED OPERATIONS -------------------------------------------------------- +// +// The following operators are supported: unary: ++ (both prefix and postfix), +// +, -, ! (logical not), ~ (one's complement); comparison: ==, !=, <, <=, >, +// >=; numerical: +, -, *, /; assignment: =, +=, -=, /=, *=; stream: <<. Each +// operator allows the same IntTypeName and the ValueType to be used on +// both left- and right-hand sides. +// +// It also supports an accessor value() returning the stored value as ValueType, +// and a templatized accessor value() method that serves as syntactic sugar +// for static_cast(var.value()). These accessors are useful when assigning +// the stored value into protocol buffer fields and using it as printf args. +// +// The class also defines a hash functor that allows the IntType to be used +// as key to hashable containers such as std::unordered_map and +// std::unordered_set. 
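+//
+// For example (illustrative sketch, not from the original header):
+//
+//   TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDocID, int64);
+//   std::unordered_set<GlobalDocID, GlobalDocID::Hasher> seen;
+//   seen.insert(GlobalDocID(42));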
+// +// We suggest using the IntTypeIndexedContainer wrapper around FixedArray and +// STL vector (see int-type-indexed-container.h) if an IntType is intended to +// be used as an index into these containers. These wrappers are indexed in a +// type-safe manner using IntTypes to ensure type-safety. +// +// NB: this implementation does not attempt to abide by or enforce dimensional +// analysis on these scalar types. +// +// EXAMPLES -------------------------------------------------------------------- +// +// DEFINE_INT_TYPE(GlobalDocID, int64); +// GlobalDocID global = 3; +// cout << global; <-- Prints 3 to stdout. +// +// for (GlobalDocID i(0); i < global; ++i) { +// cout << i; +// } <-- Print(ln)s 0 1 2 to stdout +// +// DEFINE_INT_TYPE(LocalDocID, int64); +// LocalDocID local; +// cout << local; <-- Prints 0 to stdout it default +// initializes the value to 0. +// +// local = 5; +// local *= 2; +// LocalDocID l(local); +// cout << l + local; <-- Prints 20 to stdout. +// +// GenericSearchRequest request; +// request.set_doc_id(global.value()); <-- Uses value() to extract the value +// from the IntType class. +// +// REMARKS --------------------------------------------------------------------- +// +// The following bad usage is permissible although discouraged. Essentially, it +// involves using the value*() accessors to extract the native integer type out +// of the IntType class. Keep in mind that the primary reason for the IntType +// class is to prevent *accidental* mingling of similar logical integer types -- +// and not type casting from one type to another. +// +// DEFINE_INT_TYPE(GlobalDocID, int64); +// DEFINE_INT_TYPE(LocalDocID, int64); +// GlobalDocID global; +// LocalDocID local; +// +// global = local.value(); <-- Compiles fine. +// +// void GetGlobalDoc(GlobalDocID global) { ... +// GetGlobalDoc(local.value()); <-- Compiles fine. +// +// void GetGlobalDoc(int64 global) { ... +// GetGlobalDoc(local.value()); <-- Compiles fine. + +#ifndef TENSORFLOW_LIB_GTL_INT_TYPE_H_ +#define TENSORFLOW_LIB_GTL_INT_TYPE_H_ + +#include +#include +#include +#include // NOLINT +#include + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace gtl { + +template +class IntType; + +// Defines the IntType using value_type and typedefs it to int_type_name. +// The struct int_type_name ## _tag_ trickery is needed to ensure that a new +// type is created per int_type_name. +#define TF_LIB_GTL_DEFINE_INT_TYPE(int_type_name, value_type) \ + struct int_type_name##_tag_ {}; \ + typedef ::tensorflow::gtl::IntType \ + int_type_name; + +// Holds an integer value (of type ValueType) and behaves as a ValueType by +// exposing assignment, unary, comparison, and arithmetic operators. +// +// The template parameter IntTypeName defines the name for the int type and must +// be unique within a binary (the convenient DEFINE_INT_TYPE macro at the end of +// the file generates a unique IntTypeName). The parameter ValueType defines +// the integer type value (see supported list above). +// +// This class is NOT thread-safe. +template +class IntType { + public: + typedef _ValueType ValueType; // for non-member operators + typedef IntType ThisType; // Syntactic sugar. + + // Note that this may change from time to time without notice. + struct Hasher { + size_t operator()(const IntType& arg) const { + return static_cast(arg.value()); + } + }; + + public: + // Default c'tor initializing value_ to 0. + constexpr IntType() : value_(0) {} + // C'tor explicitly initializing from a ValueType. 
+ constexpr explicit IntType(ValueType value) : value_(value) {} + + // IntType uses the default copy constructor, destructor and assign operator. + // The defaults are sufficient and omitting them allows the compiler to add + // the move constructor/assignment. + + // -- ACCESSORS -------------------------------------------------------------- + // The class provides a value() accessor returning the stored ValueType value_ + // as well as a templatized accessor that is just a syntactic sugar for + // static_cast(var.value()); + constexpr ValueType value() const { return value_; } + + template + constexpr ValType value() const { + return static_cast(value_); + } + + // -- UNARY OPERATORS -------------------------------------------------------- + ThisType& operator++() { // prefix ++ + ++value_; + return *this; + } + const ThisType operator++(int v) { // postfix ++ + ThisType temp(*this); + ++value_; + return temp; + } + ThisType& operator--() { // prefix -- + --value_; + return *this; + } + const ThisType operator--(int v) { // postfix -- + ThisType temp(*this); + --value_; + return temp; + } + + constexpr bool operator!() const { return value_ == 0; } + constexpr const ThisType operator+() const { return ThisType(value_); } + constexpr const ThisType operator-() const { return ThisType(-value_); } + constexpr const ThisType operator~() const { return ThisType(~value_); } + +// -- ASSIGNMENT OPERATORS --------------------------------------------------- +// We support the following assignment operators: =, +=, -=, *=, /=, <<=, >>= +// and %= for both ThisType and ValueType. +#define INT_TYPE_ASSIGNMENT_OP(op) \ + ThisType& operator op(const ThisType& arg_value) { \ + value_ op arg_value.value(); \ + return *this; \ + } \ + ThisType& operator op(ValueType arg_value) { \ + value_ op arg_value; \ + return *this; \ + } + INT_TYPE_ASSIGNMENT_OP(+= ); + INT_TYPE_ASSIGNMENT_OP(-= ); + INT_TYPE_ASSIGNMENT_OP(*= ); + INT_TYPE_ASSIGNMENT_OP(/= ); + INT_TYPE_ASSIGNMENT_OP(<<= ); // NOLINT + INT_TYPE_ASSIGNMENT_OP(>>= ); // NOLINT + INT_TYPE_ASSIGNMENT_OP(%= ); +#undef INT_TYPE_ASSIGNMENT_OP + + ThisType& operator=(ValueType arg_value) { + value_ = arg_value; + return *this; + } + + private: + // The integer value of type ValueType. + ValueType value_; + + static_assert(std::is_integral::value, "invalid integer type"); +} TF_PACKED; + +// -- NON-MEMBER STREAM OPERATORS ---------------------------------------------- +// We provide the << operator, primarily for logging purposes. Currently, there +// seems to be no need for an >> operator. +template +std::ostream& operator<<(std::ostream& os, // NOLINT + IntType arg) { + return os << arg.value(); +} + +// -- NON-MEMBER ARITHMETIC OPERATORS ------------------------------------------ +// We support only the +, -, *, and / operators with the same IntType and +// ValueType types. The reason is to allow simple manipulation on these IDs +// when used as indices in vectors and arrays. +// +// NB: Although it is possible to do IntType * IntType and IntType / IntType, +// it is probably non-sensical from a dimensionality analysis perspective. 
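+//
+// For instance (hedged illustration, not in the original comment):
+//
+//   TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDocID, int64);
+//   GlobalDocID a(3), b(4);
+//   GlobalDocID c = a + b;      // OK, c.value() == 7
+//   GlobalDocID d = a + 10;     // OK, ValueType on the right-hand side
+//   // a + LocalDocID(1);       // would not compile: different IntTypes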
+#define INT_TYPE_ARITHMETIC_OP(op) \ + template \ + static inline constexpr IntType operator op( \ + IntType id_1, \ + IntType id_2) { \ + return IntType(id_1.value() op id_2.value()); \ + } \ + template \ + static inline constexpr IntType operator op( \ + IntType id, \ + typename IntType::ValueType arg_val) { \ + return IntType(id.value() op arg_val); \ + } \ + template \ + static inline constexpr IntType operator op( \ + typename IntType::ValueType arg_val, \ + IntType id) { \ + return IntType(arg_val op id.value()); \ + } +INT_TYPE_ARITHMETIC_OP(+); +INT_TYPE_ARITHMETIC_OP(-); +INT_TYPE_ARITHMETIC_OP(*); +INT_TYPE_ARITHMETIC_OP(/ ); +INT_TYPE_ARITHMETIC_OP(<< ); // NOLINT +INT_TYPE_ARITHMETIC_OP(>> ); // NOLINT +INT_TYPE_ARITHMETIC_OP(% ); +#undef INT_TYPE_ARITHMETIC_OP + +// -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------ +// Static inline comparison operators. We allow all comparison operators among +// the following types (OP \in [==, !=, <, <=, >, >=]: +// IntType OP IntType +// IntType OP ValueType +// ValueType OP IntType +#define INT_TYPE_COMPARISON_OP(op) \ + template \ + static inline constexpr bool operator op( \ + IntType id_1, \ + IntType id_2) { \ + return id_1.value() op id_2.value(); \ + } \ + template \ + static inline constexpr bool operator op( \ + IntType id, \ + typename IntType::ValueType val) { \ + return id.value() op val; \ + } \ + template \ + static inline constexpr bool operator op( \ + typename IntType::ValueType val, \ + IntType id) { \ + return val op id.value(); \ + } +INT_TYPE_COMPARISON_OP(== ); // NOLINT +INT_TYPE_COMPARISON_OP(!= ); // NOLINT +INT_TYPE_COMPARISON_OP(< ); // NOLINT +INT_TYPE_COMPARISON_OP(<= ); // NOLINT +INT_TYPE_COMPARISON_OP(> ); // NOLINT +INT_TYPE_COMPARISON_OP(>= ); // NOLINT +#undef INT_TYPE_COMPARISON_OP + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_INT_TYPE_H_ diff --git a/tensorflow/core/lib/gtl/int_type_test.cc b/tensorflow/core/lib/gtl/int_type_test.cc new file mode 100644 index 0000000000..694886d345 --- /dev/null +++ b/tensorflow/core/lib/gtl/int_type_test.cc @@ -0,0 +1,282 @@ +// Unit test cases for IntType. + +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/int_type.h" +#include + +namespace tensorflow { + +TF_LIB_GTL_DEFINE_INT_TYPE(Int8_IT, int8); +TF_LIB_GTL_DEFINE_INT_TYPE(UInt8_IT, uint8); +TF_LIB_GTL_DEFINE_INT_TYPE(Int16_IT, int16); +TF_LIB_GTL_DEFINE_INT_TYPE(UInt16_IT, uint16); +TF_LIB_GTL_DEFINE_INT_TYPE(Int32_IT, int32); +TF_LIB_GTL_DEFINE_INT_TYPE(Int64_IT, int64); +TF_LIB_GTL_DEFINE_INT_TYPE(UInt32_IT, uint32); +TF_LIB_GTL_DEFINE_INT_TYPE(UInt64_IT, uint64); +TF_LIB_GTL_DEFINE_INT_TYPE(Long_IT, long); // NOLINT + +template +class IntTypeTest : public ::testing::Test { + public: + typedef IntType_Type T; +}; + +// All tests below will be executed on all supported IntTypes. 
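+// (Illustrative note, not part of the original test: TYPED_TEST_CASE below
+// instantiates every TYPED_TEST once per entry in SupportedIntTypes, so each
+// check is repeated for Int8_IT, UInt8_IT, and so on.)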
+typedef ::testing::Types SupportedIntTypes; + +TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes); + +TYPED_TEST(IntTypeTest, TestInitialization) { + constexpr typename TestFixture::T a; + constexpr typename TestFixture::T b(1); + constexpr typename TestFixture::T c(b); + EXPECT_EQ(0, a); // default initialization to 0 + EXPECT_EQ(1, b); + EXPECT_EQ(1, c); +} + +TYPED_TEST(IntTypeTest, TestOperators) { + typename TestFixture::T a(0); + typename TestFixture::T b(1); + typename TestFixture::T c(2); + constexpr typename TestFixture::T d(3); + constexpr typename TestFixture::T e(4); + + // On all EXPECT_EQ below, we use the accessor value() as to not invoke the + // comparison operators which must themselves be tested. + + // -- UNARY OPERATORS -------------------------------------------------------- + EXPECT_EQ(0, (a++).value()); + EXPECT_EQ(2, (++a).value()); + EXPECT_EQ(2, (a--).value()); + EXPECT_EQ(0, (--a).value()); + + EXPECT_EQ(true, !a); + EXPECT_EQ(false, !b); + static_assert(!d == false, "Unary operator! failed"); + + EXPECT_EQ(a.value(), +a); + static_assert(+d == d.value(), "Unary operator+ failed"); + EXPECT_EQ(-a.value(), -a); + static_assert(-d == -d.value(), "Unary operator- failed"); + EXPECT_EQ(~a.value(), ~a); // ~zero + EXPECT_EQ(~b.value(), ~b); // ~non-zero + static_assert(~d == ~d.value(), "Unary operator~ failed"); + + // -- ASSIGNMENT OPERATORS --------------------------------------------------- + // We test all assignment operators using IntType and constant as arguments. + // We also test the return from the operators. + // From same IntType + c = a = b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + // From constant + c = b = 2; + EXPECT_EQ(2, b.value()); + EXPECT_EQ(2, c.value()); + // From same IntType + c = a += b; + EXPECT_EQ(3, a.value()); + EXPECT_EQ(3, c.value()); + c = a -= b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a *= b; + EXPECT_EQ(2, a.value()); + EXPECT_EQ(2, c.value()); + c = a /= b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a <<= b; + EXPECT_EQ(4, a.value()); + EXPECT_EQ(4, c.value()); + c = a >>= b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a %= b; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + // From constant + c = a += 2; + EXPECT_EQ(3, a.value()); + EXPECT_EQ(3, c.value()); + c = a -= 2; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a *= 2; + EXPECT_EQ(2, a.value()); + EXPECT_EQ(2, c.value()); + c = a /= 2; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a <<= 2; + EXPECT_EQ(4, a.value()); + EXPECT_EQ(4, c.value()); + c = a >>= 2; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + c = a %= 2; + EXPECT_EQ(1, a.value()); + EXPECT_EQ(1, c.value()); + + // -- COMPARISON OPERATORS --------------------------------------------------- + a = 0; + b = 1; + + EXPECT_FALSE(a == b); + EXPECT_TRUE(a == 0); // NOLINT + EXPECT_FALSE(1 == a); // NOLINT + static_assert(d == d, "operator== failed"); + static_assert(d == 3, "operator== failed"); + static_assert(3 == d, "operator== failed"); + EXPECT_TRUE(a != b); + EXPECT_TRUE(a != 1); // NOLINT + EXPECT_FALSE(0 != a); // NOLINT + static_assert(d != e, "operator!= failed"); + static_assert(d != 4, "operator!= failed"); + static_assert(4 != d, "operator!= failed"); + EXPECT_TRUE(a < b); + EXPECT_TRUE(a < 1); // NOLINT + EXPECT_FALSE(0 < a); // NOLINT + static_assert(d < e, "operator< failed"); + static_assert(d < 4, "operator< failed"); + static_assert(3 < e, "operator< failed"); + EXPECT_TRUE(a <= b); + 
EXPECT_TRUE(a <= 1); // NOLINT + EXPECT_TRUE(0 <= a); // NOLINT + static_assert(d <= e, "operator<= failed"); + static_assert(d <= 4, "operator<= failed"); + static_assert(3 <= e, "operator<= failed"); + EXPECT_FALSE(a > b); + EXPECT_FALSE(a > 1); // NOLINT + EXPECT_FALSE(0 > a); // NOLINT + static_assert(e > d, "operator> failed"); + static_assert(e > 3, "operator> failed"); + static_assert(4 > d, "operator> failed"); + EXPECT_FALSE(a >= b); + EXPECT_FALSE(a >= 1); // NOLINT + EXPECT_TRUE(0 >= a); // NOLINT + static_assert(e >= d, "operator>= failed"); + static_assert(e >= 3, "operator>= failed"); + static_assert(4 >= d, "operator>= failed"); + + // -- BINARY OPERATORS ------------------------------------------------------- + a = 1; + b = 3; + EXPECT_EQ(4, (a + b).value()); + EXPECT_EQ(4, (a + 3).value()); + EXPECT_EQ(4, (1 + b).value()); + static_assert((d + e).value() == 7, "Binary operator+ failed"); + static_assert((d + 4).value() == 7, "Binary operator+ failed"); + static_assert((3 + e).value() == 7, "Binary operator+ failed"); + EXPECT_EQ(2, (b - a).value()); + EXPECT_EQ(2, (b - 1).value()); + EXPECT_EQ(2, (3 - a).value()); + static_assert((e - d).value() == 1, "Binary operator- failed"); + static_assert((e - 3).value() == 1, "Binary operator- failed"); + static_assert((4 - d).value() == 1, "Binary operator- failed"); + EXPECT_EQ(3, (a * b).value()); + EXPECT_EQ(3, (a * 3).value()); + EXPECT_EQ(3, (1 * b).value()); + static_assert((d * e).value() == 12, "Binary operator* failed"); + static_assert((d * 4).value() == 12, "Binary operator* failed"); + static_assert((3 * e).value() == 12, "Binary operator* failed"); + EXPECT_EQ(0, (a / b).value()); + EXPECT_EQ(0, (a / 3).value()); + EXPECT_EQ(0, (1 / b).value()); + static_assert((d / e).value() == 0, "Binary operator/ failed"); + static_assert((d / 4).value() == 0, "Binary operator/ failed"); + static_assert((3 / e).value() == 0, "Binary operator/ failed"); + EXPECT_EQ(8, (a << b).value()); + EXPECT_EQ(8, (a << 3).value()); + EXPECT_EQ(8, (1 << b).value()); + static_assert((d << e).value() == 48, "Binary operator<< failed"); + static_assert((d << 4).value() == 48, "Binary operator<< failed"); + static_assert((3 << e).value() == 48, "Binary operator<< failed"); + b = 8; + EXPECT_EQ(4, (b >> a).value()); + EXPECT_EQ(4, (b >> 1).value()); + EXPECT_EQ(4, (8 >> a).value()); + static_assert((d >> e).value() == 0, "Binary operator>> failed"); + static_assert((d >> 4).value() == 0, "Binary operator>> failed"); + static_assert((3 >> e).value() == 0, "Binary operator>> failed"); + b = 3; + a = 2; + EXPECT_EQ(1, (b % a).value()); + EXPECT_EQ(1, (b % 2).value()); + EXPECT_EQ(1, (3 % a).value()); + static_assert((e % d).value() == 1, "Binary operator% failed"); + static_assert((e % 3).value() == 1, "Binary operator% failed"); + static_assert((4 % d).value() == 1, "Binary operator% failed"); +} + +TYPED_TEST(IntTypeTest, TestHashFunctor) { + std::unordered_map map; + typename TestFixture::T a(0); + map[a] = 'c'; + EXPECT_EQ('c', map[a]); + map[++a] = 'o'; + EXPECT_EQ('o', map[a]); + + typename TestFixture::T b(a); + EXPECT_EQ(typename TestFixture::T::Hasher()(a), + typename TestFixture::T::Hasher()(b)); +} + +// Tests the use of the templatized value accessor that performs static_casts. +// We use -1 to force casting in unsigned integers. 
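As a small fragment illustrating the cast-style accessor that the next test exercises (Int32_IT is one of the test IntTypes defined above; the wrapped value assumes the usual unsigned conversion rules):

Int32_IT id(-1);
uint8 as_byte = id.value<uint8>();   // same as static_cast<uint8>(id.value()), i.e. 255
int64 widened = id.value<int64>();   // still -1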
+TYPED_TEST(IntTypeTest, TestValueAccessor) { + constexpr typename TestFixture::T::ValueType i = -1; + constexpr typename TestFixture::T int_type(i); + EXPECT_EQ(i, int_type.value()); + static_assert(int_type.value() == i, "value() failed"); + // The use of the keyword 'template' (suggested by Clang) is only necessary + // as this code is part of a template class. Weird syntax though. Good news + // is that only int_type.value() is needed in most code. + EXPECT_EQ(static_cast(i), int_type.template value()); + EXPECT_EQ(static_cast(i), int_type.template value()); + EXPECT_EQ(static_cast(i), int_type.template value()); + EXPECT_EQ(static_cast(i), int_type.template value()); + EXPECT_EQ(static_cast(i), int_type.template value()); + EXPECT_EQ(static_cast(i), int_type.template value()); + EXPECT_EQ(static_cast(i), int_type.template value()); + EXPECT_EQ(static_cast(i), int_type.template value()); // NOLINT + static_assert(int_type.template value() == static_cast(i), + "value() failed"); +} + +TYPED_TEST(IntTypeTest, TestMove) { + // Check that the int types have move constructor/assignment. + // We do this by composing a struct with an int type and a unique_ptr. This + // struct can't be copied due to the unique_ptr, so it must be moved. + // If this compiles, it means that the int types have move operators. + struct NotCopyable { + typename TestFixture::T inttype; + std::unique_ptr ptr; + + static NotCopyable Make(int i) { + NotCopyable f; + f.inttype = typename TestFixture::T(i); + f.ptr.reset(new int(i)); + return f; + } + }; + + // Test move constructor. + NotCopyable foo = NotCopyable::Make(123); + EXPECT_EQ(123, foo.inttype); + EXPECT_EQ(123, *foo.ptr); + + // Test move assignment. + foo = NotCopyable::Make(321); + EXPECT_EQ(321, foo.inttype); + EXPECT_EQ(321, *foo.ptr); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/iterator_range.h b/tensorflow/core/lib/gtl/iterator_range.h new file mode 100644 index 0000000000..baec85c40a --- /dev/null +++ b/tensorflow/core/lib/gtl/iterator_range.h @@ -0,0 +1,49 @@ +// This provides a very simple, boring adaptor for a begin and end iterator +// into a range type. This should be used to build range views that work well +// with range based for loops and range based constructors. +// +// Note that code here follows more standards-based coding conventions as it +// is mirroring proposed interfaces for standardization. +// +// Converted from chandlerc@'s code to Google style by joshl@. + +#ifndef TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_ +#define TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_ + +#include + +namespace tensorflow { +namespace gtl { + +// A range adaptor for a pair of iterators. +// +// This just wraps two iterators into a range-compatible interface. Nothing +// fancy at all. +template +class iterator_range { + public: + iterator_range() : begin_iterator_(), end_iterator_() {} + iterator_range(IteratorT begin_iterator, IteratorT end_iterator) + : begin_iterator_(std::move(begin_iterator)), + end_iterator_(std::move(end_iterator)) {} + + IteratorT begin() const { return begin_iterator_; } + IteratorT end() const { return end_iterator_; } + + private: + IteratorT begin_iterator_, end_iterator_; +}; + +// Convenience function for iterating over sub-ranges. +// +// This provides a bit of syntactic sugar to make using sub-ranges +// in for loops a bit easier. Analogous to std::make_pair(). 
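For instance, iterating over an interior slice of a vector might look like the sketch below (SumInterior is a hypothetical helper, not part of this library):

#include <vector>

#include "tensorflow/core/lib/gtl/iterator_range.h"

int SumInterior(const std::vector<int>& v) {
  if (v.size() < 2) return 0;
  int sum = 0;
  // Skip the first and last element using the adaptor defined below.
  for (int x : tensorflow::gtl::make_range(v.begin() + 1, v.end() - 1)) {
    sum += x;
  }
  return sum;
}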
+template +iterator_range make_range(T x, T y) { + return iterator_range(std::move(x), std::move(y)); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_ diff --git a/tensorflow/core/lib/gtl/iterator_range_test.cc b/tensorflow/core/lib/gtl/iterator_range_test.cc new file mode 100644 index 0000000000..328be4ecbc --- /dev/null +++ b/tensorflow/core/lib/gtl/iterator_range_test.cc @@ -0,0 +1,60 @@ +#include "tensorflow/core/lib/gtl/iterator_range.h" + +#include +#include "tensorflow/core/platform/port.h" +#include + +namespace tensorflow { +namespace gtl { +namespace { + +TEST(IteratorRange, WholeVector) { + std::vector v = {2, 3, 5, 7, 11, 13}; + iterator_range::iterator> range(v.begin(), v.end()); + int index = 0; + for (int prime : range) { + ASSERT_LT(index, v.size()); + EXPECT_EQ(v[index], prime); + ++index; + } + EXPECT_EQ(v.size(), index); +} + +TEST(IteratorRange, VectorMakeRange) { + std::vector v = {2, 3, 5, 7, 11, 13}; + auto range = make_range(v.begin(), v.end()); + int index = 0; + for (int prime : range) { + ASSERT_LT(index, v.size()); + EXPECT_EQ(v[index], prime); + ++index; + } + EXPECT_EQ(v.size(), index); +} + +TEST(IteratorRange, PartArray) { + int v[] = {2, 3, 5, 7, 11, 13}; + iterator_range range(&v[1], &v[4]); // 3, 5, 7 + int index = 1; + for (int prime : range) { + ASSERT_LT(index, TF_ARRAYSIZE(v)); + EXPECT_EQ(v[index], prime); + ++index; + } + EXPECT_EQ(4, index); +} + +TEST(IteratorRange, ArrayMakeRange) { + int v[] = {2, 3, 5, 7, 11, 13}; + auto range = make_range(&v[1], &v[4]); // 3, 5, 7 + int index = 1; + for (int prime : range) { + ASSERT_LT(index, TF_ARRAYSIZE(v)); + EXPECT_EQ(v[index], prime); + ++index; + } + EXPECT_EQ(4, index); +} +} // namespace +} // namespace gtl +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/manual_constructor.h b/tensorflow/core/lib/gtl/manual_constructor.h new file mode 100644 index 0000000000..39f029ed4a --- /dev/null +++ b/tensorflow/core/lib/gtl/manual_constructor.h @@ -0,0 +1,230 @@ +// ManualConstructor statically-allocates space in which to store some +// object, but does not initialize it. You can then call the constructor +// and destructor for the object yourself as you see fit. This is useful +// for memory management optimizations, where you want to initialize and +// destroy an object multiple times but only allocate it once. +// +// (When I say ManualConstructor statically allocates space, I mean that +// the ManualConstructor object itself is forced to be the right size.) + +#ifndef TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_ +#define TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/port.h" // For aligned_malloc/aligned_free + +namespace tensorflow { +namespace gtl { +namespace internal { + +// +// Provides a char array with the exact same alignment as another type. The +// first parameter must be a complete type, the second parameter is how many +// of that type to provide space for. +// +// TF_LIB_GTL_ALIGNED_CHAR_ARRAY(struct stat, 16) storage_; +// +// Because MSVC and older GCCs require that the argument to their alignment +// construct to be a literal constant integer, we use a template instantiated +// at all the possible powers of two. 
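As a rough sketch of the Init()/Destroy() cycle that the ManualConstructor described above is meant to support, built on the alignment machinery defined below (Widget is an illustrative type, not part of this library):

#include <string>

#include "tensorflow/core/lib/gtl/manual_constructor.h"

struct Widget {
  explicit Widget(int n) : label(n, 'x') {}
  std::string label;
};

void ManualConstructorSketch() {
  tensorflow::ManualConstructor<Widget> slot;  // aligned raw storage, nothing constructed yet
  slot.Init(3);        // placement-new constructs Widget("xxx") in place
  slot->label += "!";  // access the live object through operator->
  slot.Destroy();      // explicit destructor call; the storage stays usable
  slot.Init(5);        // the same storage can be reused for a new object
  slot.Destroy();
}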
+#ifndef SWIG +template +struct AlignType {}; +template +struct AlignType<0, size> { + typedef char result[size]; +}; +#if defined(COMPILER_MSVC) +#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __declspec(align(X)) +#define TF_LIB_GTL_ALIGN_OF(T) __alignof(T) +#elif defined(COMPILER_GCC3) || __GNUC__ >= 3 || defined(__APPLE__) || \ + defined(COMPILER_ICC) || defined(OS_NACL) || defined(__clang__) +#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __attribute__((aligned(X))) +#define TF_LIB_GTL_ALIGN_OF(T) __alignof__(T) +#endif + +#if defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) + +#define TF_LIB_GTL_ALIGNTYPE_TEMPLATE(X) \ + template \ + struct AlignType { \ + typedef TF_LIB_GTL_ALIGN_ATTRIBUTE(X) char result[size]; \ + } + +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(16); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(32); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(64); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(128); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(256); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(512); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1024); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2048); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4096); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8192); +// Any larger and MSVC++ will complain. + +#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \ + typename tensorflow::gtl::internal::AlignType::result + +#undef TF_LIB_GTL_ALIGNTYPE_TEMPLATE +#undef TF_LIB_GTL_ALIGN_ATTRIBUTE + +#else // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) +#error "You must define TF_LIB_GTL_ALIGNED_CHAR_ARRAY for your compiler." +#endif // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) + +#else // !SWIG + +// SWIG can't represent alignment and doesn't care about alignment on data +// members (it works fine without it). +template +struct AlignType { + typedef char result[Size]; +}; +#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \ + tensorflow::gtl::internal::AlignType::result + +// Enough to parse with SWIG, will never be used by running code. +#define TF_LIB_GTL_ALIGN_OF(Type) 16 + +#endif // !SWIG + +} // namespace internal +} // namespace gtl + +template +class ManualConstructor { + public: + // No constructor or destructor because one of the most useful uses of + // this class is as part of a union, and members of a union cannot have + // constructors or destructors. And, anyway, the whole point of this + // class is to bypass these. + + // Support users creating arrays of ManualConstructor<>s. This ensures that + // the array itself has the correct alignment. + static void* operator new[](size_t size) { + return port::aligned_malloc(size, TF_LIB_GTL_ALIGN_OF(Type)); + } + static void operator delete[](void* mem) { port::aligned_free(mem); } + + inline Type* get() { return reinterpret_cast(space_); } + inline const Type* get() const { + return reinterpret_cast(space_); + } + + inline Type* operator->() { return get(); } + inline const Type* operator->() const { return get(); } + + inline Type& operator*() { return *get(); } + inline const Type& operator*() const { return *get(); } + + inline void Init() { new (space_) Type; } + +// Init() constructs the Type instance using the given arguments +// (which are forwarded to Type's constructor). In C++11, Init() can +// take any number of arguments of any type, and forwards them perfectly. +// On pre-C++11 platforms, it can take up to 11 arguments, and may not be +// able to forward certain kinds of arguments. 
+// +// Note that Init() with no arguments performs default-initialization, +// not zero-initialization (i.e it behaves the same as "new Type;", not +// "new Type();"), so it will leave non-class types uninitialized. +#ifdef LANG_CXX11 + template + inline void Init(Ts&&... args) { // NOLINT + new (space_) Type(std::forward(args)...); // NOLINT + } +#else // !defined(LANG_CXX11) + template + inline void Init(const T1& p1) { + new (space_) Type(p1); + } + + template + inline void Init(const T1& p1, const T2& p2) { + new (space_) Type(p1, p2); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3) { + new (space_) Type(p1, p2, p3); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4) { + new (space_) Type(p1, p2, p3, p4); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5) { + new (space_) Type(p1, p2, p3, p4, p5); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6) { + new (space_) Type(p1, p2, p3, p4, p5, p6); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9, const T10& p10) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9, const T10& p10, const T11& p11) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11); + } +#endif // LANG_CXX11 + + inline void Destroy() { get()->~Type(); } + + private: + TF_LIB_GTL_ALIGNED_CHAR_ARRAY(Type, 1) space_; +}; + +#undef TF_LIB_GTL_ALIGNED_CHAR_ARRAY +#undef TF_LIB_GTL_ALIGN_OF + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_ diff --git a/tensorflow/core/lib/gtl/manual_constructor_test.cc b/tensorflow/core/lib/gtl/manual_constructor_test.cc new file mode 100644 index 0000000000..a929591be2 --- /dev/null +++ b/tensorflow/core/lib/gtl/manual_constructor_test.cc @@ -0,0 +1,113 @@ +#include "tensorflow/core/lib/gtl/manual_constructor.h" + +#include + +#include "tensorflow/core/platform/logging.h" +#include + +namespace tensorflow { +namespace { + +static int constructor_count_ = 0; + +template +struct TestN { + TestN() { ++constructor_count_; } + ~TestN() { --constructor_count_; } + char a[kSize]; +}; + +typedef TestN<1> Test1; +typedef TestN<2> Test2; +typedef TestN<3> Test3; +typedef TestN<4> Test4; +typedef TestN<5> Test5; +typedef TestN<9> Test9; +typedef TestN<15> Test15; + +} // namespace + +namespace { + +TEST(ManualConstructorTest, Sizeof) { + CHECK_EQ(sizeof(ManualConstructor), sizeof(Test1)); + CHECK_EQ(sizeof(ManualConstructor), sizeof(Test2)); + CHECK_EQ(sizeof(ManualConstructor), sizeof(Test3)); + 
CHECK_EQ(sizeof(ManualConstructor), sizeof(Test4)); + CHECK_EQ(sizeof(ManualConstructor), sizeof(Test5)); + CHECK_EQ(sizeof(ManualConstructor), sizeof(Test9)); + CHECK_EQ(sizeof(ManualConstructor), sizeof(Test15)); + + CHECK_EQ(constructor_count_, 0); + ManualConstructor mt[4]; + CHECK_EQ(sizeof(mt), 4); + CHECK_EQ(constructor_count_, 0); + mt[0].Init(); + CHECK_EQ(constructor_count_, 1); + mt[0].Destroy(); +} + +TEST(ManualConstructorTest, Alignment) { + // We want to make sure that ManualConstructor aligns its memory properly + // on a word barrier. Otherwise, it might be unexpectedly slow, since + // memory access will be unaligned. + + struct { + char a; + ManualConstructor b; + } test1; + struct { + char a; + void* b; + } control1; + + // TODO(bww): Make these tests more direct with C++11 alignment_of::value. + EXPECT_EQ(reinterpret_cast(test1.b.get()) - &test1.a, + reinterpret_cast(&control1.b) - &control1.a); + EXPECT_EQ(reinterpret_cast(test1.b.get()) % sizeof(control1.b), 0); + + struct { + char a; + ManualConstructor b; + } test2; + struct { + char a; + long double b; + } control2; + + EXPECT_EQ(reinterpret_cast(test2.b.get()) - &test2.a, + reinterpret_cast(&control2.b) - &control2.a); +#ifdef ARCH_K8 + EXPECT_EQ(reinterpret_cast(test2.b.get()) % 16, 0); +#endif +#ifdef ARCH_PIII + EXPECT_EQ(reinterpret_cast(test2.b.get()) % 4, 0); +#endif +} + +TEST(ManualConstructorTest, DefaultInitialize) { + struct X { + X() : x(123) {} + int x; + }; + union { + ManualConstructor x; + ManualConstructor y; + } u; + *u.y = -1; + u.x.Init(); // should default-initialize u.x + EXPECT_EQ(123, u.x->x); +} + +TEST(ManualConstructorTest, ZeroInitializePOD) { + union { + ManualConstructor x; + ManualConstructor y; + } u; + *u.y = -1; + u.x.Init(); // should not zero-initialize u.x + EXPECT_EQ(-1, *u.y); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/map_util.h b/tensorflow/core/lib/gtl/map_util.h new file mode 100644 index 0000000000..c953de57c7 --- /dev/null +++ b/tensorflow/core/lib/gtl/map_util.h @@ -0,0 +1,123 @@ +// This file provides utility functions for use with STL map-like data +// structures, such as std::map and hash_map. Some functions will also work with +// sets, such as ContainsKey(). + +#ifndef TENSORFLOW_LIB_GTL_MAP_UTIL_H_ +#define TENSORFLOW_LIB_GTL_MAP_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +namespace tensorflow { +namespace gtl { + +// Returns a pointer to the const value associated with the given key if it +// exists, or NULL otherwise. +template +const typename Collection::value_type::second_type* FindOrNull( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Same as above but returns a pointer to the non-const value. +template +typename Collection::value_type::second_type* FindOrNull( + Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Returns the pointer value associated with the given key. If none is found, +// NULL is returned. The function is designed to be used with a map of keys to +// pointers. +// +// This function does not distinguish between a missing key and a key mapped +// to a NULL value. 
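A small sketch of how these lookup helpers compose (the maps and keys are illustrative only):

#include <map>
#include <string>

#include "tensorflow/core/lib/gtl/map_util.h"

void MapUtilSketch() {
  std::map<std::string, int> counts{{"a", 1}};
  const int* hit = tensorflow::gtl::FindOrNull(counts, "a");   // points at counts["a"]
  const int* miss = tensorflow::gtl::FindOrNull(counts, "b");  // NULL
  int n = tensorflow::gtl::FindWithDefault(counts, "b", 0);    // 0; the map is not modified

  std::map<std::string, int*> ptrs;
  int* p = tensorflow::gtl::FindPtrOrNull(ptrs, "b");  // NULL: missing key and NULL value look alike
  (void)hit; (void)miss; (void)n; (void)p;
}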
+template +typename Collection::value_type::second_type FindPtrOrNull( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return typename Collection::value_type::second_type(); + } + return it->second; +} + +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to the provided default value. +// +// WARNING: If a temporary object is passed as the default "value," +// this function will return a reference to that temporary object, +// which will be destroyed at the end of the statement. A common +// example: if you have a map with string values, and you pass a char* +// as the default "value," either use the returned value immediately +// or store it in a string (not string&). +template +const typename Collection::value_type::second_type& FindWithDefault( + const Collection& collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return value; + } + return it->second; +} + +// Inserts the given key and value into the given collection if and only if the +// given key did NOT already exist in the collection. If the key previously +// existed in the collection, the value is not changed. Returns true if the +// key-value pair was inserted; returns false if the key was already present. +template +bool InsertIfNotPresent(Collection* const collection, + const typename Collection::value_type& vt) { + return collection->insert(vt).second; +} + +// Same as above except the key and value are passed separately. +template +bool InsertIfNotPresent( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return InsertIfNotPresent(collection, + typename Collection::value_type(key, value)); +} + +// Looks up a given key and value pair in a collection and inserts the key-value +// pair if it's not already present. Returns a reference to the value associated +// with the key. +template +typename Collection::value_type::second_type& LookupOrInsert( + Collection* const collection, const typename Collection::value_type& vt) { + return collection->insert(vt).first->second; +} + +// Same as above except the key-value are passed separately. +template +typename Collection::value_type::second_type& LookupOrInsert( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return LookupOrInsert(collection, + typename Collection::value_type(key, value)); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_MAP_UTIL_H_ diff --git a/tensorflow/core/lib/gtl/map_util_test.cc b/tensorflow/core/lib/gtl/map_util_test.cc new file mode 100644 index 0000000000..356f987337 --- /dev/null +++ b/tensorflow/core/lib/gtl/map_util_test.cc @@ -0,0 +1,47 @@ +#include "tensorflow/core/lib/gtl/map_util.h" + +#include +#include +#include +#include "tensorflow/core/platform/port.h" + +#include + +namespace tensorflow { + +TEST(MapUtil, Find) { + typedef std::map Map; + Map m; + + // Check that I can use a type that's implicitly convertible to the + // key or value type, such as const char* -> string. 
+ EXPECT_EQ("", gtl::FindWithDefault(m, "foo", "")); + m["foo"] = "bar"; + EXPECT_EQ("bar", gtl::FindWithDefault(m, "foo", "")); + EXPECT_EQ("bar", *gtl::FindOrNull(m, "foo")); + string str; + EXPECT_TRUE(m.count("foo") > 0); + EXPECT_EQ(m["foo"], "bar"); +} + +TEST(MapUtil, LookupOrInsert) { + typedef std::map Map; + Map m; + + // Check that I can use a type that's implicitly convertible to the + // key or value type, such as const char* -> string. + EXPECT_EQ("xyz", gtl::LookupOrInsert(&m, "foo", "xyz")); + EXPECT_EQ("xyz", gtl::LookupOrInsert(&m, "foo", "abc")); +} + +TEST(MapUtil, InsertIfNotPresent) { + // Set operations + typedef std::set Set; + Set s; + EXPECT_TRUE(gtl::InsertIfNotPresent(&s, 0)); + EXPECT_EQ(s.count(0), 1); + EXPECT_FALSE(gtl::InsertIfNotPresent(&s, 0)); + EXPECT_EQ(s.count(0), 1); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/stl_util.h b/tensorflow/core/lib/gtl/stl_util.h new file mode 100644 index 0000000000..83abcd6b55 --- /dev/null +++ b/tensorflow/core/lib/gtl/stl_util.h @@ -0,0 +1,130 @@ +// This file provides utility functions for use with STL + +#ifndef TENSORFLOW_LIB_GTL_STL_UTIL_H_ +#define TENSORFLOW_LIB_GTL_STL_UTIL_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace tensorflow { +namespace gtl { + +// Returns a mutable char* pointing to a string's internal buffer, which may not +// be null-terminated. Returns NULL for an empty string. If not non-null, +// writing through this pointer will modify the string. +// +// string_as_array(&str)[i] is valid for 0 <= i < str.size() until the +// next call to a string method that invalidates iterators. +// +// In C++11 you may simply use &str[0] to get a mutable char*. +// +// Prior to C++11, there was no standard-blessed way of getting a mutable +// reference to a string's internal buffer. The requirement that string be +// contiguous is officially part of the C++11 standard [string.require]/5. +// According to Matt Austern, this should already work on all current C++98 +// implementations. +inline char* string_as_array(string* str) { + return str->empty() ? NULL : &*str->begin(); +} + +// Returns the T* array for the given vector, or NULL if the vector was empty. +// +// Note: If you know the array will never be empty, you can use &*v.begin() +// directly, but that is may dump core if v is empty. This function is the most +// efficient code that will work, taking into account how our STL is actually +// implemented. THIS IS NON-PORTABLE CODE, so use this function instead of +// repeating the nonportable code everywhere. If our STL implementation changes, +// we will need to change this as well. +template +inline T* vector_as_array(std::vector* v) { +#if defined NDEBUG && !defined _GLIBCXX_DEBUG + return &*v->begin(); +#else + return v->empty() ? NULL : &*v->begin(); +#endif +} +// vector_as_array overload for const std::vector<>. +template +inline const T* vector_as_array(const std::vector* v) { +#if defined NDEBUG && !defined _GLIBCXX_DEBUG + return &*v->begin(); +#else + return v->empty() ? NULL : &*v->begin(); +#endif +} + +// Like str->resize(new_size), except any new characters added to "*str" as a +// result of resizing may be left uninitialized, rather than being filled with +// '0' bytes. Typically used when code is then going to overwrite the backing +// store of the string with known data. Uses a Google extension to ::string. 
+inline void STLStringResizeUninitialized(string* s, size_t new_size) { +#if __google_stl_resize_uninitialized_string + s->resize_uninitialized(new_size); +#else + s->resize(new_size); +#endif +} + +// Calls delete (non-array version) on the SECOND item (pointer) in each pair in +// the range [begin, end). +// +// Note: If you're calling this on an entire container, you probably want to +// call STLDeleteValues(&container) instead, or use ValueDeleter. +template +void STLDeleteContainerPairSecondPointers(ForwardIterator begin, + ForwardIterator end) { + while (begin != end) { + ForwardIterator temp = begin; + ++begin; + delete temp->second; + } +} + +// Deletes all the elements in an STL container and clears the container. This +// function is suitable for use with a vector, set, hash_set, or any other STL +// container which defines sensible begin(), end(), and clear() methods. +// +// If container is NULL, this function is a no-op. +template +void STLDeleteElements(T* container) { + if (!container) return; + auto it = container->begin(); + while (it != container->end()) { + auto temp = it; + ++it; + delete *temp; + } + container->clear(); +} + +// Given an STL container consisting of (key, value) pairs, STLDeleteValues +// deletes all the "value" components and clears the container. Does nothing in +// the case it's given a NULL pointer. +template +void STLDeleteValues(T* container) { + if (!container) return; + auto it = container->begin(); + while (it != container->end()) { + auto temp = it; + ++it; + delete temp->second; + } + container->clear(); +} + +// Sorts and removes duplicates from a sequence container. +template +inline void STLSortAndRemoveDuplicates(T* v) { + std::sort(v->begin(), v->end()); + v->erase(std::unique(v->begin(), v->end()), v->end()); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_STL_UTIL_H_ diff --git a/tensorflow/core/lib/gtl/top_n.h b/tensorflow/core/lib/gtl/top_n.h new file mode 100644 index 0000000000..b95b998c21 --- /dev/null +++ b/tensorflow/core/lib/gtl/top_n.h @@ -0,0 +1,324 @@ +// This simple class finds the top n elements of an incrementally provided set +// of elements which you push one at a time. If the number of elements exceeds +// n, the lowest elements are incrementally dropped. At the end you get +// a vector of the top elements sorted in descending order (through Extract() or +// ExtractNondestructive()), or a vector of the top elements but not sorted +// (through ExtractUnsorted() or ExtractUnsortedNondestructive()). +// +// The value n is specified in the constructor. If there are p elements pushed +// altogether: +// The total storage requirements are O(min(n, p)) elements +// The running time is O(p * log(min(n, p))) comparisons +// If n is a constant, the total storage required is a constant and the running +// time is linear in p. +// +// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p) +// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements, +// discarding the lowest n elements whenever the buffer is full using a linear- +// time median algorithm. This may have better performance when the input +// sequence is partially sorted. +// +// NOTE(zhifengc): This class should be redesigned to avoid reallocating a +// vector for each Extract. 
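As a rough usage sketch (Extract() hands back a heap-allocated vector that the caller must own, per the contract described above):

#include <memory>
#include <vector>

#include "tensorflow/core/lib/gtl/top_n.h"

void TopNSketch() {
  tensorflow::gtl::TopN<int> top3(3);
  const int values[] = {5, 1, 9, 4, 7};
  for (int v : values) top3.push(v);

  // Take ownership of the result; for this input it holds {9, 7, 5}.
  std::unique_ptr<std::vector<int>> best(top3.Extract());

  top3.Reset();  // required before pushing again after a destructive Extract()
}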
+ +#ifndef TENSORFLOW_LIB_GTL_TOP_N_H_ +#define TENSORFLOW_LIB_GTL_TOP_N_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace gtl { + +// Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate, +// not the more commonly used "less" predicate. +// +// If you use a "less" predicate here, the TopN will pick out the bottom N +// elements out of the ones passed to it, and it will return them sorted in +// ascending order. +// +// TopN is rule-of-zero copyable and movable if its members are. +template > +class TopN { + public: + // The TopN is in one of the three states: + // + // o UNORDERED: this is the state an instance is originally in, + // where the elements are completely orderless. + // + // o BOTTOM_KNOWN: in this state, we keep the invariant that there + // is at least one element in it, and the lowest element is at + // position 0. The elements in other positions remain + // unsorted. This state is reached if the state was originally + // UNORDERED and a peek_bottom() function call is invoked. + // + // o HEAP_SORTED: in this state, the array is kept as a heap and + // there are exactly (limit_+1) elements in the array. This + // state is reached when at least (limit_+1) elements are + // pushed in. + // + // The state transition graph is at follows: + // + // peek_bottom() (limit_+1) elements + // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED + // | ^ + // | (limit_+1) elements | + // +-----------------------------------------------------------+ + + enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED }; + using UnsortedIterator = typename std::vector::const_iterator; + + // 'limit' is the maximum number of top results to return. + explicit TopN(size_t limit) : TopN(limit, Cmp()) {} + TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {} + + size_t limit() const { return limit_; } + + // Number of elements currently held by this TopN object. This + // will be no greater than 'limit' passed to the constructor. + size_t size() const { return std::min(elements_.size(), limit_); } + + bool empty() const { return size() == 0; } + + // If you know how many elements you will push at the time you create the + // TopN object, you can call reserve to preallocate the memory that TopN + // will need to process all 'n' pushes. Calling this method is optional. + void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); } + + // Push 'v'. If the maximum number of elements was exceeded, drop the + // lowest element and return it in 'dropped' (if given). If the maximum is not + // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or + // nullptr, in which case it is not filled in. + // Requires: T is CopyAssignable, Swappable + void push(const T &v) { push(v, nullptr); } + void push(const T &v, T *dropped) { PushInternal(v, dropped); } + + // Move overloads of push. + // Requires: T is MoveAssignable, Swappable + void push(T &&v) { // NOLINT(build/c++11) + push(std::move(v), nullptr); + } + void push(T &&v, T *dropped) { // NOLINT(build/c++11) + PushInternal(std::move(v), dropped); + } + + // Peeks the bottom result without calling Extract() + const T &peek_bottom(); + + // Extract the elements as a vector sorted in descending order. The caller + // assumes ownership of the vector and must delete it when done. This is a + // destructive operation. The only method that can be called immediately + // after Extract() is Reset(). 
+ std::vector *Extract(); + + // Similar to Extract(), but makes no guarantees the elements are in sorted + // order. As with Extract(), the caller assumes ownership of the vector and + // must delete it when done. This is a destructive operation. The only + // method that can be called immediately after ExtractUnsorted() is Reset(). + std::vector *ExtractUnsorted(); + + // A non-destructive version of Extract(). Copy the elements in a new vector + // sorted in descending order and return it. The caller assumes ownership of + // the new vector and must delete it when done. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + std::vector *ExtractNondestructive() const; + + // A non-destructive version of Extract(). Copy the elements to a given + // vector sorted in descending order. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractNondestructive(std::vector *output) const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements in a new + // vector and return it, with no guarantees the elements are in sorted order. + // The caller assumes ownership of the new vector and must delete it when + // done. After calling ExtractUnsortedNondestructive(), the caller can + // continue to push() new elements. + std::vector *ExtractUnsortedNondestructive() const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements into + // a given vector, with no guarantees the elements are in sorted order. + // After calling ExtractUnsortedNondestructive(), the caller can continue + // to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractUnsortedNondestructive(std::vector *output) const; + + // Return an iterator to the beginning (end) of the container, + // with no guarantees about the order of iteration. These iterators are + // invalidated by mutation of the data structure. + UnsortedIterator unsorted_begin() const { return elements_.begin(); } + UnsortedIterator unsorted_end() const { return elements_.begin() + size(); } + + // Accessor for comparator template argument. + Cmp *comparator() { return &cmp_; } + + // This removes all elements. If Extract() or ExtractUnsorted() have been + // called, this will put it back in an empty but useable state. + void Reset(); + + private: + template + void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11) + + // elements_ can be in one of two states: + // elements_.size() <= limit_: elements_ is an unsorted vector of elements + // pushed so far. + // elements_.size() > limit_: The last element of elements_ is unused; + // the other elements of elements_ are an stl heap whose size is exactly + // limit_. In this case elements_.size() is exactly one greater than + // limit_, but don't use "elements_.size() == limit_ + 1" to check for + // that because you'll get a false positive if limit_ == size_t(-1). 
+ std::vector elements_; + size_t limit_; // Maximum number of elements to find + Cmp cmp_; // Greater-than comparison function + State state_ = UNORDERED; +}; + +// ---------------------------------------------------------------------- +// Implementations of non-inline functions + +template +template +void TopN::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11) + if (limit_ == 0) { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + return; + } + if (state_ != HEAP_SORTED) { + elements_.push_back(std::forward(v)); // NOLINT(build/c++11) + if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) { + // Easy case: we just pushed the new element back + } else { + // To maintain the BOTTOM_KNOWN state, we need to make sure that + // the element at position 0 is always the smallest. So we put + // the new element at position 0 and push the original bottom + // element in the back. + // Warning: this code is subtle. + using std::swap; + swap(elements_.front(), elements_.back()); + } + if (elements_.size() == limit_ + 1) { + // Transition from unsorted vector to a heap. + std::make_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + state_ = HEAP_SORTED; + } + } else { + // Only insert the new element if it is greater than the least element. + if (cmp_(v, elements_.front())) { + elements_.back() = std::forward(v); // NOLINT(build/c++11) + std::push_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + } else { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + } + } +} + +template +const T &TopN::peek_bottom() { + CHECK(!empty()); + if (state_ == UNORDERED) { + // We need to do a linear scan to find out the bottom element + int min_candidate = 0; + for (size_t i = 1; i < elements_.size(); ++i) { + if (cmp_(elements_[min_candidate], elements_[i])) { + min_candidate = i; + } + } + // By swapping the element at position 0 and the minimal + // element, we transition to the BOTTOM_KNOWN state + if (min_candidate != 0) { + using std::swap; + swap(elements_[0], elements_[min_candidate]); + } + state_ = BOTTOM_KNOWN; + } + return elements_.front(); +} + +template +std::vector *TopN::Extract() { + auto out = new std::vector; + out->swap(elements_); + if (state_ != HEAP_SORTED) { + std::sort(out->begin(), out->end(), cmp_); + } else { + out->pop_back(); + std::sort_heap(out->begin(), out->end(), cmp_); + } + return out; +} + +template +std::vector *TopN::ExtractUnsorted() { + auto out = new std::vector; + out->swap(elements_); + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. 
+ out->pop_back(); + } + return out; +} + +template +std::vector *TopN::ExtractNondestructive() const { + auto out = new std::vector; + ExtractNondestructive(out); + return out; +} + +template +void TopN::ExtractNondestructive(std::vector *output) const { + CHECK(output); + *output = elements_; + if (state_ != HEAP_SORTED) { + std::sort(output->begin(), output->end(), cmp_); + } else { + output->pop_back(); + std::sort_heap(output->begin(), output->end(), cmp_); + } +} + +template +std::vector *TopN::ExtractUnsortedNondestructive() const { + auto elements = new std::vector; + ExtractUnsortedNondestructive(elements); + return elements; +} + +template +void TopN::ExtractUnsortedNondestructive(std::vector *output) const { + CHECK(output); + *output = elements_; + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + output->pop_back(); + } +} + +template +void TopN::Reset() { + elements_.clear(); + state_ = UNORDERED; +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_TOP_N_H_ diff --git a/tensorflow/core/lib/gtl/top_n_test.cc b/tensorflow/core/lib/gtl/top_n_test.cc new file mode 100644 index 0000000000..1812a1bd3f --- /dev/null +++ b/tensorflow/core/lib/gtl/top_n_test.cc @@ -0,0 +1,249 @@ +// Unit test for TopN. + +#include "tensorflow/core/lib/gtl/top_n.h" + +#include +#include + +#include +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace { + +using tensorflow::gtl::TopN; +using tensorflow::random::PhiloxRandom; +using tensorflow::random::SimplePhilox; +using tensorflow::string; + +// Move the contents from an owned raw pointer, returning by value. +// Objects are easier to manage by value. 
+template +T ConsumeRawPtr(T *p) { + T tmp = std::move(*p); + delete p; + return tmp; +} + +template +void TestIntTopNHelper(size_t limit, size_t n_elements, const Cmp &cmp, + SimplePhilox *random, bool test_peek, + bool test_extract_unsorted) { + LOG(INFO) << "Testing limit=" << limit << ", n_elements=" << n_elements + << ", test_peek=" << test_peek + << ", test_extract_unsorted=" << test_extract_unsorted; + TopN top(limit, cmp); + std::vector shadow(n_elements); + for (int i = 0; i != n_elements; ++i) shadow[i] = random->Uniform(limit); + for (int e : shadow) top.push(e); + std::sort(shadow.begin(), shadow.end(), cmp); + size_t top_size = std::min(limit, n_elements); + EXPECT_EQ(top_size, top.size()); + if (test_peek && top_size != 0) { + EXPECT_EQ(shadow[top_size - 1], top.peek_bottom()); + } + std::vector v; + if (test_extract_unsorted) { + v = ConsumeRawPtr(top.ExtractUnsorted()); + std::sort(v.begin(), v.end(), cmp); + } else { + v = ConsumeRawPtr(top.Extract()); + } + EXPECT_EQ(top_size, v.size()); + for (int i = 0; i != top_size; ++i) { + VLOG(1) << "Top element " << v[i]; + EXPECT_EQ(shadow[i], v[i]); + } +} + +template +void TestIntTopN(size_t limit, size_t n_elements, const Cmp &cmp, + SimplePhilox *random) { + // Test peek_bottom() and Extract() + TestIntTopNHelper(limit, n_elements, cmp, random, true, false); + // Test Extract() + TestIntTopNHelper(limit, n_elements, cmp, random, false, false); + // Test peek_bottom() and ExtractUnsorted() + TestIntTopNHelper(limit, n_elements, cmp, random, true, true); + // Test ExtractUnsorted() + TestIntTopNHelper(limit, n_elements, cmp, random, false, true); +} + +TEST(TopNTest, Misc) { + PhiloxRandom philox(1, 1); + SimplePhilox random(&philox); + + TestIntTopN(0, 5, std::greater(), &random); + TestIntTopN(32, 0, std::greater(), &random); + TestIntTopN(6, 6, std::greater(), &random); + TestIntTopN(6, 6, std::less(), &random); + TestIntTopN(1000, 999, std::greater(), &random); + TestIntTopN(1000, 1000, std::greater(), &random); + TestIntTopN(1000, 1001, std::greater(), &random); + TestIntTopN(2300, 28393, std::less(), &random); + TestIntTopN(30, 100, std::greater(), &random); + TestIntTopN(100, 30, std::less(), &random); + TestIntTopN(size_t(-1), 3, std::greater(), &random); + TestIntTopN(size_t(-1), 0, std::greater(), &random); + TestIntTopN(0, 5, std::greater(), &random); +} + +TEST(TopNTest, String) { + LOG(INFO) << "Testing strings"; + + TopN top(3); + EXPECT_TRUE(top.empty()); + top.push("abracadabra"); + top.push("waldemar"); + EXPECT_EQ(2, top.size()); + EXPECT_EQ("abracadabra", top.peek_bottom()); + top.push(""); + EXPECT_EQ(3, top.size()); + EXPECT_EQ("", top.peek_bottom()); + top.push("top"); + EXPECT_EQ(3, top.size()); + EXPECT_EQ("abracadabra", top.peek_bottom()); + top.push("Google"); + top.push("test"); + EXPECT_EQ(3, top.size()); + EXPECT_EQ("test", top.peek_bottom()); + TopN top2(top); + TopN top3(5); + top3 = top; + EXPECT_EQ("test", top3.peek_bottom()); + { + std::vector s = ConsumeRawPtr(top.Extract()); + EXPECT_EQ(s[0], "waldemar"); + EXPECT_EQ(s[1], "top"); + EXPECT_EQ(s[2], "test"); + } + + top2.push("zero"); + EXPECT_EQ(top2.peek_bottom(), "top"); + + { + std::vector s = ConsumeRawPtr(top2.Extract()); + EXPECT_EQ(s[0], "zero"); + EXPECT_EQ(s[1], "waldemar"); + EXPECT_EQ(s[2], "top"); + } + { + std::vector s = ConsumeRawPtr(top3.Extract()); + EXPECT_EQ(s[0], "waldemar"); + EXPECT_EQ(s[1], "top"); + EXPECT_EQ(s[2], "test"); + } + + TopN top4(3); + // Run this test twice to check Reset(): + for (int i = 0; i < 2; ++i) { 
+ top4.push("abcd"); + top4.push("ijkl"); + top4.push("efgh"); + top4.push("mnop"); + std::vector s = ConsumeRawPtr(top4.Extract()); + EXPECT_EQ(s[0], "mnop"); + EXPECT_EQ(s[1], "ijkl"); + EXPECT_EQ(s[2], "efgh"); + top4.Reset(); + } +} + +// Test that pointers aren't leaked from a TopN if we use the 2-argument version +// of push(). +TEST(TopNTest, Ptr) { + LOG(INFO) << "Testing 2-argument push()"; + TopN topn(3); + for (int i = 0; i < 8; ++i) { + string *dropped = NULL; + topn.push(new string(std::to_string(i)), &dropped); + delete dropped; + } + + for (int i = 8; i > 0; --i) { + string *dropped = NULL; + topn.push(new string(std::to_string(i)), &dropped); + delete dropped; + } + + std::vector extract = ConsumeRawPtr(topn.Extract()); + tensorflow::gtl::STLDeleteElements(&extract); +} + +struct PointeeGreater { + template + bool operator()(const T &a, const T &b) const { + return *a > *b; + } +}; + +TEST(TopNTest, MoveOnly) { + using StrPtr = std::unique_ptr; + TopN topn(3); + for (int i = 0; i < 8; ++i) topn.push(StrPtr(new string(std::to_string(i)))); + for (int i = 8; i > 0; --i) topn.push(StrPtr(new string(std::to_string(i)))); + + std::vector extract = ConsumeRawPtr(topn.Extract()); + EXPECT_EQ(extract.size(), 3); + EXPECT_EQ(*(extract[0]), "8"); + EXPECT_EQ(*(extract[1]), "7"); + EXPECT_EQ(*(extract[2]), "7"); +} + +// Test that Nondestructive extracts do not need a Reset() afterwards, +// and that pointers aren't leaked from a TopN after calling them. +TEST(TopNTest, Nondestructive) { + LOG(INFO) << "Testing Nondestructive extracts"; + TopN top4(4); + for (int i = 0; i < 8; ++i) { + top4.push(i); + std::vector v = ConsumeRawPtr(top4.ExtractNondestructive()); + EXPECT_EQ(std::min(i + 1, 4), v.size()); + for (size_t j = 0; j < v.size(); ++j) EXPECT_EQ(i - j, v[j]); + } + + TopN top3(3); + for (int i = 0; i < 8; ++i) { + top3.push(i); + std::vector v = ConsumeRawPtr(top3.ExtractUnsortedNondestructive()); + std::sort(v.begin(), v.end(), std::greater()); + EXPECT_EQ(std::min(i + 1, 3), v.size()); + for (size_t j = 0; j < v.size(); ++j) EXPECT_EQ(i - j, v[j]); + } +} + +struct ForbiddenCmp { + bool operator()(int lhs, int rhs) const { + LOG(FATAL) << "ForbiddenCmp called " << lhs << " " << rhs; + } +}; + +TEST(TopNTest, ZeroLimit) { + TopN top(0); + top.push(1); + top.push(2); + + int dropped = -1; + top.push(1, &dropped); + top.push(2, &dropped); + + std::vector v; + top.ExtractNondestructive(&v); + EXPECT_EQ(0, v.size()); +} + +TEST(TopNTest, Iteration) { + TopN top(4); + for (int i = 0; i < 8; ++i) top.push(i); + std::vector actual(top.unsorted_begin(), top.unsorted_end()); + // Check that we have 4,5,6,7 as the top 4 (in some order, so we sort) + sort(actual.begin(), actual.end()); + EXPECT_EQ(actual.size(), 4); + EXPECT_EQ(actual[0], 4); + EXPECT_EQ(actual[1], 5); + EXPECT_EQ(actual[2], 6); + EXPECT_EQ(actual[3], 7); +} +} // namespace diff --git a/tensorflow/core/lib/hash/crc32c.cc b/tensorflow/core/lib/hash/crc32c.cc new file mode 100644 index 0000000000..3bef1cf78d --- /dev/null +++ b/tensorflow/core/lib/hash/crc32c.cc @@ -0,0 +1,244 @@ +// A portable implementation of crc32c, optimized to handle +// four bytes at a time. 
+ +#include "tensorflow/core/lib/hash/crc32c.h" + +#include +#include "tensorflow/core/lib/core/coding.h" + +namespace tensorflow { +namespace crc32c { + +static const uint32 table0_[256] = { + 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c, + 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, + 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c, + 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, + 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc, + 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, + 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512, + 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, + 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad, + 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, + 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf, + 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, + 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f, + 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, + 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f, + 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, + 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e, + 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, + 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e, + 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, + 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de, + 0xdde0eb2a, 0x2f8b6829, 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, + 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4, + 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, + 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b, + 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, + 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5, + 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, + 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975, + 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, + 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905, + 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, + 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8, + 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, + 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8, + 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, + 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78, + 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, + 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6, + 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, + 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69, + 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, + 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351}; +static const uint32 table1_[256] = { + 0x00000000, 0x13a29877, 0x274530ee, 0x34e7a899, 0x4e8a61dc, 0x5d28f9ab, + 0x69cf5132, 0x7a6dc945, 0x9d14c3b8, 0x8eb65bcf, 0xba51f356, 0xa9f36b21, + 0xd39ea264, 
0xc03c3a13, 0xf4db928a, 0xe7790afd, 0x3fc5f181, 0x2c6769f6, + 0x1880c16f, 0x0b225918, 0x714f905d, 0x62ed082a, 0x560aa0b3, 0x45a838c4, + 0xa2d13239, 0xb173aa4e, 0x859402d7, 0x96369aa0, 0xec5b53e5, 0xfff9cb92, + 0xcb1e630b, 0xd8bcfb7c, 0x7f8be302, 0x6c297b75, 0x58ced3ec, 0x4b6c4b9b, + 0x310182de, 0x22a31aa9, 0x1644b230, 0x05e62a47, 0xe29f20ba, 0xf13db8cd, + 0xc5da1054, 0xd6788823, 0xac154166, 0xbfb7d911, 0x8b507188, 0x98f2e9ff, + 0x404e1283, 0x53ec8af4, 0x670b226d, 0x74a9ba1a, 0x0ec4735f, 0x1d66eb28, + 0x298143b1, 0x3a23dbc6, 0xdd5ad13b, 0xcef8494c, 0xfa1fe1d5, 0xe9bd79a2, + 0x93d0b0e7, 0x80722890, 0xb4958009, 0xa737187e, 0xff17c604, 0xecb55e73, + 0xd852f6ea, 0xcbf06e9d, 0xb19da7d8, 0xa23f3faf, 0x96d89736, 0x857a0f41, + 0x620305bc, 0x71a19dcb, 0x45463552, 0x56e4ad25, 0x2c896460, 0x3f2bfc17, + 0x0bcc548e, 0x186eccf9, 0xc0d23785, 0xd370aff2, 0xe797076b, 0xf4359f1c, + 0x8e585659, 0x9dface2e, 0xa91d66b7, 0xbabffec0, 0x5dc6f43d, 0x4e646c4a, + 0x7a83c4d3, 0x69215ca4, 0x134c95e1, 0x00ee0d96, 0x3409a50f, 0x27ab3d78, + 0x809c2506, 0x933ebd71, 0xa7d915e8, 0xb47b8d9f, 0xce1644da, 0xddb4dcad, + 0xe9537434, 0xfaf1ec43, 0x1d88e6be, 0x0e2a7ec9, 0x3acdd650, 0x296f4e27, + 0x53028762, 0x40a01f15, 0x7447b78c, 0x67e52ffb, 0xbf59d487, 0xacfb4cf0, + 0x981ce469, 0x8bbe7c1e, 0xf1d3b55b, 0xe2712d2c, 0xd69685b5, 0xc5341dc2, + 0x224d173f, 0x31ef8f48, 0x050827d1, 0x16aabfa6, 0x6cc776e3, 0x7f65ee94, + 0x4b82460d, 0x5820de7a, 0xfbc3faf9, 0xe861628e, 0xdc86ca17, 0xcf245260, + 0xb5499b25, 0xa6eb0352, 0x920cabcb, 0x81ae33bc, 0x66d73941, 0x7575a136, + 0x419209af, 0x523091d8, 0x285d589d, 0x3bffc0ea, 0x0f186873, 0x1cbaf004, + 0xc4060b78, 0xd7a4930f, 0xe3433b96, 0xf0e1a3e1, 0x8a8c6aa4, 0x992ef2d3, + 0xadc95a4a, 0xbe6bc23d, 0x5912c8c0, 0x4ab050b7, 0x7e57f82e, 0x6df56059, + 0x1798a91c, 0x043a316b, 0x30dd99f2, 0x237f0185, 0x844819fb, 0x97ea818c, + 0xa30d2915, 0xb0afb162, 0xcac27827, 0xd960e050, 0xed8748c9, 0xfe25d0be, + 0x195cda43, 0x0afe4234, 0x3e19eaad, 0x2dbb72da, 0x57d6bb9f, 0x447423e8, + 0x70938b71, 0x63311306, 0xbb8de87a, 0xa82f700d, 0x9cc8d894, 0x8f6a40e3, + 0xf50789a6, 0xe6a511d1, 0xd242b948, 0xc1e0213f, 0x26992bc2, 0x353bb3b5, + 0x01dc1b2c, 0x127e835b, 0x68134a1e, 0x7bb1d269, 0x4f567af0, 0x5cf4e287, + 0x04d43cfd, 0x1776a48a, 0x23910c13, 0x30339464, 0x4a5e5d21, 0x59fcc556, + 0x6d1b6dcf, 0x7eb9f5b8, 0x99c0ff45, 0x8a626732, 0xbe85cfab, 0xad2757dc, + 0xd74a9e99, 0xc4e806ee, 0xf00fae77, 0xe3ad3600, 0x3b11cd7c, 0x28b3550b, + 0x1c54fd92, 0x0ff665e5, 0x759baca0, 0x663934d7, 0x52de9c4e, 0x417c0439, + 0xa6050ec4, 0xb5a796b3, 0x81403e2a, 0x92e2a65d, 0xe88f6f18, 0xfb2df76f, + 0xcfca5ff6, 0xdc68c781, 0x7b5fdfff, 0x68fd4788, 0x5c1aef11, 0x4fb87766, + 0x35d5be23, 0x26772654, 0x12908ecd, 0x013216ba, 0xe64b1c47, 0xf5e98430, + 0xc10e2ca9, 0xd2acb4de, 0xa8c17d9b, 0xbb63e5ec, 0x8f844d75, 0x9c26d502, + 0x449a2e7e, 0x5738b609, 0x63df1e90, 0x707d86e7, 0x0a104fa2, 0x19b2d7d5, + 0x2d557f4c, 0x3ef7e73b, 0xd98eedc6, 0xca2c75b1, 0xfecbdd28, 0xed69455f, + 0x97048c1a, 0x84a6146d, 0xb041bcf4, 0xa3e32483}; +static const uint32 table2_[256] = { + 0x00000000, 0xa541927e, 0x4f6f520d, 0xea2ec073, 0x9edea41a, 0x3b9f3664, + 0xd1b1f617, 0x74f06469, 0x38513ec5, 0x9d10acbb, 0x773e6cc8, 0xd27ffeb6, + 0xa68f9adf, 0x03ce08a1, 0xe9e0c8d2, 0x4ca15aac, 0x70a27d8a, 0xd5e3eff4, + 0x3fcd2f87, 0x9a8cbdf9, 0xee7cd990, 0x4b3d4bee, 0xa1138b9d, 0x045219e3, + 0x48f3434f, 0xedb2d131, 0x079c1142, 0xa2dd833c, 0xd62de755, 0x736c752b, + 0x9942b558, 0x3c032726, 0xe144fb14, 0x4405696a, 0xae2ba919, 0x0b6a3b67, + 0x7f9a5f0e, 0xdadbcd70, 0x30f50d03, 0x95b49f7d, 0xd915c5d1, 0x7c5457af, + 
0x967a97dc, 0x333b05a2, 0x47cb61cb, 0xe28af3b5, 0x08a433c6, 0xade5a1b8, + 0x91e6869e, 0x34a714e0, 0xde89d493, 0x7bc846ed, 0x0f382284, 0xaa79b0fa, + 0x40577089, 0xe516e2f7, 0xa9b7b85b, 0x0cf62a25, 0xe6d8ea56, 0x43997828, + 0x37691c41, 0x92288e3f, 0x78064e4c, 0xdd47dc32, 0xc76580d9, 0x622412a7, + 0x880ad2d4, 0x2d4b40aa, 0x59bb24c3, 0xfcfab6bd, 0x16d476ce, 0xb395e4b0, + 0xff34be1c, 0x5a752c62, 0xb05bec11, 0x151a7e6f, 0x61ea1a06, 0xc4ab8878, + 0x2e85480b, 0x8bc4da75, 0xb7c7fd53, 0x12866f2d, 0xf8a8af5e, 0x5de93d20, + 0x29195949, 0x8c58cb37, 0x66760b44, 0xc337993a, 0x8f96c396, 0x2ad751e8, + 0xc0f9919b, 0x65b803e5, 0x1148678c, 0xb409f5f2, 0x5e273581, 0xfb66a7ff, + 0x26217bcd, 0x8360e9b3, 0x694e29c0, 0xcc0fbbbe, 0xb8ffdfd7, 0x1dbe4da9, + 0xf7908dda, 0x52d11fa4, 0x1e704508, 0xbb31d776, 0x511f1705, 0xf45e857b, + 0x80aee112, 0x25ef736c, 0xcfc1b31f, 0x6a802161, 0x56830647, 0xf3c29439, + 0x19ec544a, 0xbcadc634, 0xc85da25d, 0x6d1c3023, 0x8732f050, 0x2273622e, + 0x6ed23882, 0xcb93aafc, 0x21bd6a8f, 0x84fcf8f1, 0xf00c9c98, 0x554d0ee6, + 0xbf63ce95, 0x1a225ceb, 0x8b277743, 0x2e66e53d, 0xc448254e, 0x6109b730, + 0x15f9d359, 0xb0b84127, 0x5a968154, 0xffd7132a, 0xb3764986, 0x1637dbf8, + 0xfc191b8b, 0x595889f5, 0x2da8ed9c, 0x88e97fe2, 0x62c7bf91, 0xc7862def, + 0xfb850ac9, 0x5ec498b7, 0xb4ea58c4, 0x11abcaba, 0x655baed3, 0xc01a3cad, + 0x2a34fcde, 0x8f756ea0, 0xc3d4340c, 0x6695a672, 0x8cbb6601, 0x29faf47f, + 0x5d0a9016, 0xf84b0268, 0x1265c21b, 0xb7245065, 0x6a638c57, 0xcf221e29, + 0x250cde5a, 0x804d4c24, 0xf4bd284d, 0x51fcba33, 0xbbd27a40, 0x1e93e83e, + 0x5232b292, 0xf77320ec, 0x1d5de09f, 0xb81c72e1, 0xccec1688, 0x69ad84f6, + 0x83834485, 0x26c2d6fb, 0x1ac1f1dd, 0xbf8063a3, 0x55aea3d0, 0xf0ef31ae, + 0x841f55c7, 0x215ec7b9, 0xcb7007ca, 0x6e3195b4, 0x2290cf18, 0x87d15d66, + 0x6dff9d15, 0xc8be0f6b, 0xbc4e6b02, 0x190ff97c, 0xf321390f, 0x5660ab71, + 0x4c42f79a, 0xe90365e4, 0x032da597, 0xa66c37e9, 0xd29c5380, 0x77ddc1fe, + 0x9df3018d, 0x38b293f3, 0x7413c95f, 0xd1525b21, 0x3b7c9b52, 0x9e3d092c, + 0xeacd6d45, 0x4f8cff3b, 0xa5a23f48, 0x00e3ad36, 0x3ce08a10, 0x99a1186e, + 0x738fd81d, 0xd6ce4a63, 0xa23e2e0a, 0x077fbc74, 0xed517c07, 0x4810ee79, + 0x04b1b4d5, 0xa1f026ab, 0x4bdee6d8, 0xee9f74a6, 0x9a6f10cf, 0x3f2e82b1, + 0xd50042c2, 0x7041d0bc, 0xad060c8e, 0x08479ef0, 0xe2695e83, 0x4728ccfd, + 0x33d8a894, 0x96993aea, 0x7cb7fa99, 0xd9f668e7, 0x9557324b, 0x3016a035, + 0xda386046, 0x7f79f238, 0x0b899651, 0xaec8042f, 0x44e6c45c, 0xe1a75622, + 0xdda47104, 0x78e5e37a, 0x92cb2309, 0x378ab177, 0x437ad51e, 0xe63b4760, + 0x0c158713, 0xa954156d, 0xe5f54fc1, 0x40b4ddbf, 0xaa9a1dcc, 0x0fdb8fb2, + 0x7b2bebdb, 0xde6a79a5, 0x3444b9d6, 0x91052ba8}; +static const uint32 table3_[256] = { + 0x00000000, 0xdd45aab8, 0xbf672381, 0x62228939, 0x7b2231f3, 0xa6679b4b, + 0xc4451272, 0x1900b8ca, 0xf64463e6, 0x2b01c95e, 0x49234067, 0x9466eadf, + 0x8d665215, 0x5023f8ad, 0x32017194, 0xef44db2c, 0xe964b13d, 0x34211b85, + 0x560392bc, 0x8b463804, 0x924680ce, 0x4f032a76, 0x2d21a34f, 0xf06409f7, + 0x1f20d2db, 0xc2657863, 0xa047f15a, 0x7d025be2, 0x6402e328, 0xb9474990, + 0xdb65c0a9, 0x06206a11, 0xd725148b, 0x0a60be33, 0x6842370a, 0xb5079db2, + 0xac072578, 0x71428fc0, 0x136006f9, 0xce25ac41, 0x2161776d, 0xfc24ddd5, + 0x9e0654ec, 0x4343fe54, 0x5a43469e, 0x8706ec26, 0xe524651f, 0x3861cfa7, + 0x3e41a5b6, 0xe3040f0e, 0x81268637, 0x5c632c8f, 0x45639445, 0x98263efd, + 0xfa04b7c4, 0x27411d7c, 0xc805c650, 0x15406ce8, 0x7762e5d1, 0xaa274f69, + 0xb327f7a3, 0x6e625d1b, 0x0c40d422, 0xd1057e9a, 0xaba65fe7, 0x76e3f55f, + 0x14c17c66, 0xc984d6de, 0xd0846e14, 0x0dc1c4ac, 0x6fe34d95, 
0xb2a6e72d, + 0x5de23c01, 0x80a796b9, 0xe2851f80, 0x3fc0b538, 0x26c00df2, 0xfb85a74a, + 0x99a72e73, 0x44e284cb, 0x42c2eeda, 0x9f874462, 0xfda5cd5b, 0x20e067e3, + 0x39e0df29, 0xe4a57591, 0x8687fca8, 0x5bc25610, 0xb4868d3c, 0x69c32784, + 0x0be1aebd, 0xd6a40405, 0xcfa4bccf, 0x12e11677, 0x70c39f4e, 0xad8635f6, + 0x7c834b6c, 0xa1c6e1d4, 0xc3e468ed, 0x1ea1c255, 0x07a17a9f, 0xdae4d027, + 0xb8c6591e, 0x6583f3a6, 0x8ac7288a, 0x57828232, 0x35a00b0b, 0xe8e5a1b3, + 0xf1e51979, 0x2ca0b3c1, 0x4e823af8, 0x93c79040, 0x95e7fa51, 0x48a250e9, + 0x2a80d9d0, 0xf7c57368, 0xeec5cba2, 0x3380611a, 0x51a2e823, 0x8ce7429b, + 0x63a399b7, 0xbee6330f, 0xdcc4ba36, 0x0181108e, 0x1881a844, 0xc5c402fc, + 0xa7e68bc5, 0x7aa3217d, 0x52a0c93f, 0x8fe56387, 0xedc7eabe, 0x30824006, + 0x2982f8cc, 0xf4c75274, 0x96e5db4d, 0x4ba071f5, 0xa4e4aad9, 0x79a10061, + 0x1b838958, 0xc6c623e0, 0xdfc69b2a, 0x02833192, 0x60a1b8ab, 0xbde41213, + 0xbbc47802, 0x6681d2ba, 0x04a35b83, 0xd9e6f13b, 0xc0e649f1, 0x1da3e349, + 0x7f816a70, 0xa2c4c0c8, 0x4d801be4, 0x90c5b15c, 0xf2e73865, 0x2fa292dd, + 0x36a22a17, 0xebe780af, 0x89c50996, 0x5480a32e, 0x8585ddb4, 0x58c0770c, + 0x3ae2fe35, 0xe7a7548d, 0xfea7ec47, 0x23e246ff, 0x41c0cfc6, 0x9c85657e, + 0x73c1be52, 0xae8414ea, 0xcca69dd3, 0x11e3376b, 0x08e38fa1, 0xd5a62519, + 0xb784ac20, 0x6ac10698, 0x6ce16c89, 0xb1a4c631, 0xd3864f08, 0x0ec3e5b0, + 0x17c35d7a, 0xca86f7c2, 0xa8a47efb, 0x75e1d443, 0x9aa50f6f, 0x47e0a5d7, + 0x25c22cee, 0xf8878656, 0xe1873e9c, 0x3cc29424, 0x5ee01d1d, 0x83a5b7a5, + 0xf90696d8, 0x24433c60, 0x4661b559, 0x9b241fe1, 0x8224a72b, 0x5f610d93, + 0x3d4384aa, 0xe0062e12, 0x0f42f53e, 0xd2075f86, 0xb025d6bf, 0x6d607c07, + 0x7460c4cd, 0xa9256e75, 0xcb07e74c, 0x16424df4, 0x106227e5, 0xcd278d5d, + 0xaf050464, 0x7240aedc, 0x6b401616, 0xb605bcae, 0xd4273597, 0x09629f2f, + 0xe6264403, 0x3b63eebb, 0x59416782, 0x8404cd3a, 0x9d0475f0, 0x4041df48, + 0x22635671, 0xff26fcc9, 0x2e238253, 0xf36628eb, 0x9144a1d2, 0x4c010b6a, + 0x5501b3a0, 0x88441918, 0xea669021, 0x37233a99, 0xd867e1b5, 0x05224b0d, + 0x6700c234, 0xba45688c, 0xa345d046, 0x7e007afe, 0x1c22f3c7, 0xc167597f, + 0xc747336e, 0x1a0299d6, 0x782010ef, 0xa565ba57, 0xbc65029d, 0x6120a825, + 0x0302211c, 0xde478ba4, 0x31035088, 0xec46fa30, 0x8e647309, 0x5321d9b1, + 0x4a21617b, 0x9764cbc3, 0xf54642fa, 0x2803e842}; + +// Used to fetch a naturally-aligned 32-bit word in little endian byte-order +static inline uint32_t LE_LOAD32(const uint8_t *p) { + return core::DecodeFixed32(reinterpret_cast(p)); +} + +uint32 Extend(uint32 crc, const char *buf, size_t size) { + const uint8 *p = reinterpret_cast(buf); + const uint8 *e = p + size; + uint32 l = crc ^ 0xffffffffu; + +#define STEP1 \ + do { \ + int c = (l & 0xff) ^ *p++; \ + l = table0_[c] ^ (l >> 8); \ + } while (0) + +#define STEP4 \ + do { \ + uint32 c = l ^ LE_LOAD32(p); \ + p += 4; \ + l = table3_[c & 0xff] ^ table2_[(c >> 8) & 0xff] ^ \ + table1_[(c >> 16) & 0xff] ^ table0_[c >> 24]; \ + } while (0) + + // Point x at first 4-byte aligned byte in string. This might be + // just past the end of the string. 
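  // The two statements below round p up to the next 4-byte boundary: for
  // example, pval == 0x1001 gives pval + 3 == 0x1004 and
  // ((0x1004 >> 2) << 2) == 0x1004, so x points at 0x1004, while a pval that
  // is already a multiple of 4 maps to itself.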
+  const uintptr_t pval = reinterpret_cast<uintptr_t>(p);
+  const uint8 *x = reinterpret_cast<const uint8 *>(((pval + 3) >> 2) << 2);
+  if (x <= e) {
+    // Process bytes until finished or p is 4-byte aligned
+    while (p != x) {
+      STEP1;
+    }
+  }
+  // Process bytes 16 at a time
+  while ((e - p) >= 16) {
+    STEP4;
+    STEP4;
+    STEP4;
+    STEP4;
+  }
+  // Process bytes 4 at a time
+  while ((e - p) >= 4) {
+    STEP4;
+  }
+  // Process the last few bytes
+  while (p != e) {
+    STEP1;
+  }
+#undef STEP4
+#undef STEP1
+  return l ^ 0xffffffffu;
+}
+
+}  // namespace crc32c
+}  // namespace tensorflow
diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h
new file mode 100644
index 0000000000..f728b6f5e7
--- /dev/null
+++ b/tensorflow/core/lib/hash/crc32c.h
@@ -0,0 +1,39 @@
+#ifndef TENSORFLOW_LIB_HASH_CRC32C_H_
+#define TENSORFLOW_LIB_HASH_CRC32C_H_
+
+#include <stddef.h>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace crc32c {
+
+// Return the crc32c of concat(A, data[0,n-1]) where init_crc is the
+// crc32c of some string A.  Extend() is often used to maintain the
+// crc32c of a stream of data.
+extern uint32 Extend(uint32 init_crc, const char* data, size_t n);
+
+// Return the crc32c of data[0,n-1]
+inline uint32 Value(const char* data, size_t n) { return Extend(0, data, n); }
+
+static const uint32 kMaskDelta = 0xa282ead8ul;
+
+// Return a masked representation of crc.
+//
+// Motivation: it is problematic to compute the CRC of a string that
+// contains embedded CRCs.  Therefore we recommend that CRCs stored
+// somewhere (e.g., in files) should be masked before being stored.
+inline uint32 Mask(uint32 crc) {
+  // Rotate right by 15 bits and add a constant.
+  return ((crc >> 15) | (crc << 17)) + kMaskDelta;
+}
+
+// Return the crc whose masked representation is masked_crc.
+inline uint32 Unmask(uint32 masked_crc) {
+  uint32 rot = masked_crc - kMaskDelta;
+  return ((rot >> 17) | (rot << 15));
+}
+
+}  // namespace crc32c
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_LIB_HASH_CRC32C_H_
diff --git a/tensorflow/core/lib/hash/crc32c_test.cc b/tensorflow/core/lib/hash/crc32c_test.cc
new file mode 100644
index 0000000000..54aced3186
--- /dev/null
+++ b/tensorflow/core/lib/hash/crc32c_test.cc
@@ -0,0 +1,51 @@
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace crc32c {
+
+TEST(CRC, StandardResults) {
+  // From rfc3720 section B.4.
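  // These appear to be the standard CRC-32C (Castagnoli) check vectors from
  // that section: 32 zero bytes, 32 0xff bytes, an ascending and a descending
  // byte sequence, and a 48-byte iSCSI command PDU.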
+ char buf[32]; + + memset(buf, 0, sizeof(buf)); + ASSERT_EQ(0x8a9136aa, Value(buf, sizeof(buf))); + + memset(buf, 0xff, sizeof(buf)); + ASSERT_EQ(0x62a8ab43, Value(buf, sizeof(buf))); + + for (int i = 0; i < 32; i++) { + buf[i] = i; + } + ASSERT_EQ(0x46dd794e, Value(buf, sizeof(buf))); + + for (int i = 0; i < 32; i++) { + buf[i] = 31 - i; + } + ASSERT_EQ(0x113fdb5c, Value(buf, sizeof(buf))); + + unsigned char data[48] = { + 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + ASSERT_EQ(0xd9963a56, Value(reinterpret_cast(data), sizeof(data))); +} + +TEST(CRC, Values) { ASSERT_NE(Value("a", 1), Value("foo", 3)); } + +TEST(CRC, Extend) { + ASSERT_EQ(Value("hello world", 11), Extend(Value("hello ", 6), "world", 5)); +} + +TEST(CRC, Mask) { + uint32 crc = Value("foo", 3); + ASSERT_NE(crc, Mask(crc)); + ASSERT_NE(crc, Mask(Mask(crc))); + ASSERT_EQ(crc, Unmask(Mask(crc))); + ASSERT_EQ(crc, Unmask(Unmask(Mask(Mask(crc))))); +} + +} // namespace crc32c +} // namespace tensorflow diff --git a/tensorflow/core/lib/hash/hash.cc b/tensorflow/core/lib/hash/hash.cc new file mode 100644 index 0000000000..075d252412 --- /dev/null +++ b/tensorflow/core/lib/hash/hash.cc @@ -0,0 +1,113 @@ +#include "tensorflow/core/lib/hash/hash.h" + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/raw_coding.h" + +#include + +namespace tensorflow { + +// 0xff is in case char is signed. +static inline uint32 ByteAs32(char c) { return static_cast(c) & 0xff; } +static inline uint64 ByteAs64(char c) { return static_cast(c) & 0xff; } + +uint32 Hash32(const char* data, size_t n, uint32 seed) { + // 'm' and 'r' are mixing constants generated offline. + // They're not really 'magic', they just happen to work well. + + const uint32 m = 0x5bd1e995; + const int r = 24; + + // Initialize the hash to a 'random' value + uint32 h = seed ^ n; + + // Mix 4 bytes at a time into the hash + while (n >= 4) { + uint32 k = core::DecodeFixed32(data); + + k *= m; + k ^= k >> r; + k *= m; + + h *= m; + h ^= k; + + data += 4; + n -= 4; + } + + // Handle the last few bytes of the input array + + switch (n) { + case 3: + h ^= ByteAs32(data[2]) << 16; + TF_FALLTHROUGH_INTENDED; + case 2: + h ^= ByteAs32(data[1]) << 8; + TF_FALLTHROUGH_INTENDED; + case 1: + h ^= ByteAs32(data[0]); + h *= m; + } + + // Do a few final mixes of the hash to ensure the last few + // bytes are well-incorporated. 
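  // (The shift/multiply rounds below match the finalization ("avalanche")
  // step of MurmurHash2, which this routine appears to follow; the constant
  // m and the shift amounts 13 and 15 come from that algorithm.)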
+ + h ^= h >> 13; + h *= m; + h ^= h >> 15; + + return h; +} + +uint64 Hash64(const char* data, size_t n, uint64 seed) { + const uint64 m = 0xc6a4a7935bd1e995; + const int r = 47; + + uint64 h = seed ^ (n * m); + + while (n >= 8) { + uint64 k = core::DecodeFixed64(data); + data += 8; + n -= 8; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + switch (n) { + case 7: + h ^= ByteAs64(data[6]) << 48; + TF_FALLTHROUGH_INTENDED; + case 6: + h ^= ByteAs64(data[5]) << 40; + TF_FALLTHROUGH_INTENDED; + case 5: + h ^= ByteAs64(data[4]) << 32; + TF_FALLTHROUGH_INTENDED; + case 4: + h ^= ByteAs64(data[3]) << 24; + TF_FALLTHROUGH_INTENDED; + case 3: + h ^= ByteAs64(data[2]) << 16; + TF_FALLTHROUGH_INTENDED; + case 2: + h ^= ByteAs64(data[1]) << 8; + TF_FALLTHROUGH_INTENDED; + case 1: + h ^= ByteAs64(data[0]); + h *= m; + } + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h new file mode 100644 index 0000000000..af56218fed --- /dev/null +++ b/tensorflow/core/lib/hash/hash.h @@ -0,0 +1,28 @@ +// Simple hash functions used for internal data structures + +#ifndef TENSORFLOW_LIB_HASH_HASH_H_ +#define TENSORFLOW_LIB_HASH_HASH_H_ + +#include +#include + +#include + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +extern uint32 Hash32(const char* data, size_t n, uint32 seed); +extern uint64 Hash64(const char* data, size_t n, uint64 seed); + +inline uint64 Hash64(const char* data, size_t n) { + return Hash64(data, n, 0xDECAFCAFFE); +} + +inline uint64 Hash64(const string& str) { + return Hash64(str.data(), str.size()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_HASH_HASH_H_ diff --git a/tensorflow/core/lib/hash/hash_test.cc b/tensorflow/core/lib/hash/hash_test.cc new file mode 100644 index 0000000000..9d3b970f3b --- /dev/null +++ b/tensorflow/core/lib/hash/hash_test.cc @@ -0,0 +1,64 @@ +#include + +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { + +TEST(Hash, SignedUnsignedIssue) { + const unsigned char d1[1] = {0x62}; + const unsigned char d2[2] = {0xc3, 0x97}; + const unsigned char d3[3] = {0xe2, 0x99, 0xa5}; + const unsigned char d4[4] = {0xe1, 0x80, 0xb9, 0x32}; + const unsigned char d5[48] = { + 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + + struct Case { + uint32 hash32; + uint64 hash64; + const unsigned char* data; + size_t size; + uint32 seed; + }; + + for (Case c : std::vector{ + {0x471a8188u, 0x4c61ea3eeda4cb87ull, nullptr, 0, 0xbc9f1d34}, + {0xd615eba5u, 0x091309f7ef916c8aull, d1, sizeof(d1), 0xbc9f1d34}, + {0x0c3cccdau, 0xa815bcdf1d1af01cull, d2, sizeof(d2), 0xbc9f1d34}, + {0x3ba37e0eu, 0x02167564e4d06430ull, d3, sizeof(d3), 0xbc9f1d34}, + {0x16174eb3u, 0x8f7ed82ffc21071full, d4, sizeof(d4), 0xbc9f1d34}, + {0x98b1926cu, 0xce196580c97aff1eull, d5, sizeof(d5), 0x12345678}, + }) { + EXPECT_EQ(c.hash32, + Hash32(reinterpret_cast(c.data), c.size, c.seed)); + EXPECT_EQ(c.hash64, + Hash64(reinterpret_cast(c.data), c.size, c.seed)); + + // Check hashes with inputs aligned differently. 
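    // The 'x' padding below shifts the start of the payload to byte offsets
    // 1..7, so these expectations only hold if DecodeFixed32/DecodeFixed64
    // read unaligned data correctly.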
+ for (int align = 1; align <= 7; align++) { + std::string input(align, 'x'); + input.append(reinterpret_cast(c.data), c.size); + EXPECT_EQ(c.hash32, Hash32(&input[align], c.size, c.seed)); + EXPECT_EQ(c.hash64, Hash64(&input[align], c.size, c.seed)); + } + } +} + +static void BM_Hash32(int iters, int len) { + std::string input(len, 'x'); + uint32 h = 0; + for (int i = 0; i < iters; i++) { + h = Hash32(input.data(), len, 1); + } + testing::BytesProcessed(static_cast(iters) * len); + VLOG(1) << h; +} +BENCHMARK(BM_Hash32)->Range(1, 1024); + +} // namespace tensorflow diff --git a/tensorflow/core/lib/histogram/histogram.cc b/tensorflow/core/lib/histogram/histogram.cc new file mode 100644 index 0000000000..4c29d687b7 --- /dev/null +++ b/tensorflow/core/lib/histogram/histogram.cc @@ -0,0 +1,247 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/histogram/histogram.h" +#include +#include +#include "tensorflow/core/framework/summary.pb.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +namespace tensorflow { +namespace histogram { + +static std::vector* InitDefaultBucketsInner() { + std::vector buckets; + std::vector neg_buckets; + // Make buckets whose range grows by 10% starting at 1.0e-12 up to 1.0e20 + double v = 1.0e-12; + while (v < 1.0e20) { + buckets.push_back(v); + neg_buckets.push_back(-v); + v *= 1.1; + } + buckets.push_back(DBL_MAX); + neg_buckets.push_back(-DBL_MAX); + std::reverse(neg_buckets.begin(), neg_buckets.end()); + std::vector* result = new std::vector; + result->insert(result->end(), neg_buckets.begin(), neg_buckets.end()); + result->push_back(0.0); + result->insert(result->end(), buckets.begin(), buckets.end()); + return result; +} + +static gtl::ArraySlice InitDefaultBuckets() { + static std::vector* default_bucket_limits = InitDefaultBucketsInner(); + return *default_bucket_limits; +} + +Histogram::Histogram() : bucket_limits_(InitDefaultBuckets()) { Clear(); } + +// Create a histogram with a custom set of bucket limits, +// specified in "custom_buckets[0..custom_buckets.size()-1]" +Histogram::Histogram(gtl::ArraySlice custom_bucket_limits) + : custom_bucket_limits_(custom_bucket_limits.begin(), + custom_bucket_limits.end()), + bucket_limits_(custom_bucket_limits_) { +#ifndef NDEBUG + DCHECK_GT(bucket_limits_.size(), 0); + // Verify that the bucket boundaries are strictly increasing + for (size_t i = 1; i < bucket_limits_.size(); i++) { + DCHECK_GT(bucket_limits_[i], bucket_limits_[i - 1]); + } +#endif + Clear(); +} + +bool Histogram::DecodeFromProto(const HistogramProto& proto) { + if ((proto.bucket_size() != proto.bucket_limit_size()) || + (proto.bucket_size() == 0)) { + return false; + } + min_ = proto.min(); + max_ = proto.max(); + num_ = proto.num(); + sum_ = proto.sum(); + sum_squares_ = proto.sum_squares(); + custom_bucket_limits_.clear(); + custom_bucket_limits_.insert(custom_bucket_limits_.end(), + proto.bucket_limit().begin(), + proto.bucket_limit().end()); + bucket_limits_ = custom_bucket_limits_; + buckets_.clear(); + buckets_.insert(buckets_.end(), proto.bucket().begin(), proto.bucket().end()); + return true; +} + +void Histogram::Clear() { + min_ = bucket_limits_[bucket_limits_.size() - 1]; + max_ = -DBL_MAX; + num_ = 0; + sum_ = 0; + sum_squares_ = 0; + buckets_.resize(bucket_limits_.size()); + for (size_t i = 
0; i < bucket_limits_.size(); i++) { + buckets_[i] = 0; + } +} + +void Histogram::Add(double value) { + int b = + std::upper_bound(bucket_limits_.begin(), bucket_limits_.end(), value) - + bucket_limits_.begin(); + + buckets_[b] += 1.0; + if (min_ > value) min_ = value; + if (max_ < value) max_ = value; + num_++; + sum_ += value; + sum_squares_ += (value * value); +} + +double Histogram::Median() const { return Percentile(50.0); } + +double Histogram::Percentile(double p) const { + if (num_ == 0.0) return 0.0; + double threshold = num_ * (p / 100.0); + double sum = 0; + for (size_t b = 0; b < buckets_.size(); b++) { + sum += buckets_[b]; + if (sum >= threshold) { + // Scale linearly within this bucket + double left_point = (b == 0) ? min_ : bucket_limits_[b - 1]; + double right_point = bucket_limits_[b]; + double left_sum = sum - buckets_[b]; + double right_sum = sum; + double pos = (threshold - left_sum) / (right_sum - left_sum); + double r = left_point + (right_point - left_point) * pos; + if (r < min_) r = min_; + if (r > max_) r = max_; + return r; + } + } + return max_; +} + +double Histogram::Average() const { + if (num_ == 0.0) return 0; + return sum_ / num_; +} + +double Histogram::StandardDeviation() const { + if (num_ == 0.0) return 0; + double variance = (sum_squares_ * num_ - sum_ * sum_) / (num_ * num_); + return sqrt(variance); +} + +std::string Histogram::ToString() const { + std::string r; + char buf[200]; + snprintf(buf, sizeof(buf), "Count: %.0f Average: %.4f StdDev: %.2f\n", num_, + Average(), StandardDeviation()); + r.append(buf); + snprintf(buf, sizeof(buf), "Min: %.4f Median: %.4f Max: %.4f\n", + (num_ == 0.0 ? 0.0 : min_), Median(), max_); + r.append(buf); + r.append("------------------------------------------------------\n"); + const double mult = num_ > 0 ? 100.0 / num_ : 0.0; + double sum = 0; + for (size_t b = 0; b < buckets_.size(); b++) { + if (buckets_[b] <= 0.0) continue; + sum += buckets_[b]; + snprintf(buf, sizeof(buf), "[ %10.2g, %10.2g ) %7.0f %7.3f%% %7.3f%% ", + ((b == 0) ? -DBL_MAX : bucket_limits_[b - 1]), // left + bucket_limits_[b], // right + buckets_[b], // count + mult * buckets_[b], // percentage + mult * sum); // cum percentage + r.append(buf); + + // Add hash marks based on percentage; 20 marks for 100%. + int marks = static_cast(20 * (buckets_[b] / num_) + 0.5); + r.append(marks, '#'); + r.push_back('\n'); + } + return r; +} + +void Histogram::EncodeToProto(HistogramProto* proto, + bool preserve_zero_buckets) const { + proto->Clear(); + proto->set_min(min_); + proto->set_max(max_); + proto->set_num(num_); + proto->set_sum(sum_); + proto->set_sum_squares(sum_squares_); + for (size_t i = 0; i < buckets_.size();) { + double end = bucket_limits_[i]; + double count = buckets_[i]; + i++; + if (!preserve_zero_buckets && count <= 0.0) { + // Find run of empty buckets and collapse them into one + while (i < buckets_.size() && buckets_[i] <= 0.0) { + end = bucket_limits_[i]; + count = buckets_[i]; + i++; + } + } + proto->add_bucket_limit(end); + proto->add_bucket(count); + } + if (proto->bucket_size() == 0.0) { + // It's easier when we restore if we always have at least one bucket entry + proto->add_bucket_limit(DBL_MAX); + proto->add_bucket(0.0); + } +} + +// ThreadSafeHistogram implementation. 
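// A worked example of the interpolation performed by Percentile() above; the
// same numbers appear in TEST(Histogram, Percentile) further down.
//   Given Histogram h({0, 10, 100, DBL_MAX}) after h.Add(-2); h.Add(-2);
//   h.Add(0):
//   - the two -2 values fall in bucket 0 (limit 0) and the 0 falls in
//     bucket 1 (limit 10), so buckets_ = {2, 1, 0, 0}, min_ = -2, max_ = 0,
//     num_ = 3;
//   - Percentile(50) uses threshold = 1.5, which is reached in bucket 0,
//     whose range is taken as [min_, 0) = [-2, 0); with left_sum = 0 and
//     right_sum = 2, pos = 1.5 / 2 = 0.75 and
//     r = -2 + (0 - (-2)) * 0.75 = -0.5.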
+bool ThreadSafeHistogram::DecodeFromProto(const HistogramProto& proto) { + mutex_lock l(mu_); + return histogram_.DecodeFromProto(proto); +} + +void ThreadSafeHistogram::Clear() { + mutex_lock l(mu_); + histogram_.Clear(); +} + +void ThreadSafeHistogram::Add(double value) { + mutex_lock l(mu_); + histogram_.Add(value); +} + +void ThreadSafeHistogram::EncodeToProto(HistogramProto* proto, + bool preserve_zero_buckets) const { + mutex_lock l(mu_); + histogram_.EncodeToProto(proto, preserve_zero_buckets); +} + +double ThreadSafeHistogram::Median() const { + mutex_lock l(mu_); + return histogram_.Median(); +} + +double ThreadSafeHistogram::Percentile(double p) const { + mutex_lock l(mu_); + return histogram_.Percentile(p); +} + +double ThreadSafeHistogram::Average() const { + mutex_lock l(mu_); + return histogram_.Average(); +} + +double ThreadSafeHistogram::StandardDeviation() const { + mutex_lock l(mu_); + return histogram_.StandardDeviation(); +} + +std::string ThreadSafeHistogram::ToString() const { + mutex_lock l(mu_); + return histogram_.ToString(); +} + +} // namespace histogram +} // namespace tensorflow diff --git a/tensorflow/core/lib/histogram/histogram.h b/tensorflow/core/lib/histogram/histogram.h new file mode 100644 index 0000000000..9b655f3acb --- /dev/null +++ b/tensorflow/core/lib/histogram/histogram.h @@ -0,0 +1,119 @@ +#ifndef TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_ +#define TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_ + +#include +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +class HistogramProto; + +namespace histogram { + +class Histogram { + public: + // Create a histogram with a default set of bucket boundaries. + // Buckets near zero cover very small ranges (e.g. 10^-12), and each + // bucket range grows by ~10% as we head away from zero. The + // buckets cover the range from -DBL_MAX to DBL_MAX. + Histogram(); + + // Create a histogram with a custom set of bucket boundaries, + // specified in "custom_bucket_limits[0..custom_bucket_limits.size()-1]" + // REQUIRES: custom_bucket_limits[i] values are monotonically increasing. + // REQUIRES: custom_bucket_limits is not empty() + explicit Histogram(gtl::ArraySlice custom_bucket_limits); + + // Restore the state of a histogram that was previously encoded + // via Histogram::EncodeToProto. Note that only the bucket boundaries + // generated by EncodeToProto will be restored. + bool DecodeFromProto(const HistogramProto& proto); + + ~Histogram() {} + + void Clear(); + void Add(double value); + + // Save the current state of the histogram to "*proto". If + // "preserve_zero_buckets" is false, only non-zero bucket values and + // ranges are saved, and the bucket boundaries of zero-valued buckets + // are lost. + void EncodeToProto(HistogramProto* proto, bool preserve_zero_buckets) const; + + // Return the median of the values in the histogram + double Median() const; + + // Return the "p"th percentile [0.0..100.0] of the values in the + // distribution + double Percentile(double p) const; + + // Return the average value of the distribution + double Average() const; + + // Return the standard deviation of values in the distribution + double StandardDeviation() const; + + // Returns a multi-line human-readable string representing the histogram + // contents. 
Example output: + // Count: 4 Average: 251.7475 StdDev: 432.02 + // Min: -3.0000 Median: 5.0000 Max: 1000.0000 + // ------------------------------------------------------ + // [ -5, 0 ) 1 25.000% 25.000% ##### + // [ 0, 5 ) 1 25.000% 50.000% ##### + // [ 5, 10 ) 1 25.000% 75.000% ##### + // [ 1000, 10000 ) 1 25.000% 100.000% ##### + std::string ToString() const; + + private: + double min_; + double max_; + double num_; + double sum_; + double sum_squares_; + + std::vector custom_bucket_limits_; + gtl::ArraySlice bucket_limits_; + std::vector buckets_; + + TF_DISALLOW_COPY_AND_ASSIGN(Histogram); +}; + +// Wrapper around a Histogram object that is thread safe. +// +// All methods hold a lock while delegating to a Histogram object owned by the +// ThreadSafeHistogram instance. +// +// See Histogram for documentation of the methods. +class ThreadSafeHistogram { + public: + ThreadSafeHistogram() {} + explicit ThreadSafeHistogram(gtl::ArraySlice custom_bucket_limits) + : histogram_(custom_bucket_limits) {} + bool DecodeFromProto(const HistogramProto& proto); + + ~ThreadSafeHistogram() {} + + void Clear(); + + // TODO(mdevin): It might be a good idea to provide a AddN() + // method to avoid grabbing/releasing the lock when adding many values. + void Add(double value); + + void EncodeToProto(HistogramProto* proto, bool preserve_zero_buckets) const; + double Median() const; + double Percentile(double p) const; + double Average() const; + double StandardDeviation() const; + std::string ToString() const; + + private: + mutable mutex mu_; + Histogram histogram_ GUARDED_BY(mu_); +}; + +} // namespace histogram +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_ diff --git a/tensorflow/core/lib/histogram/histogram_test.cc b/tensorflow/core/lib/histogram/histogram_test.cc new file mode 100644 index 0000000000..ede44fe85b --- /dev/null +++ b/tensorflow/core/lib/histogram/histogram_test.cc @@ -0,0 +1,112 @@ +#include "tensorflow/core/lib/histogram/histogram.h" +#include +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/framework/summary.pb.h" +#include + +namespace tensorflow { +namespace histogram { + +static void Validate(const Histogram& h) { + string s1 = h.ToString(); + LOG(ERROR) << s1; + + HistogramProto proto_with_zeroes; + h.EncodeToProto(&proto_with_zeroes, true); + Histogram h2; + EXPECT_TRUE(h2.DecodeFromProto(proto_with_zeroes)); + string s2 = h2.ToString(); + LOG(ERROR) << s2; + + EXPECT_EQ(s1, s2); + + HistogramProto proto_no_zeroes; + h.EncodeToProto(&proto_no_zeroes, false); + LOG(ERROR) << proto_no_zeroes.DebugString(); + Histogram h3; + EXPECT_TRUE(h3.DecodeFromProto(proto_no_zeroes)); + string s3 = h3.ToString(); + LOG(ERROR) << s3; + + EXPECT_EQ(s1, s3); +} + +TEST(Histogram, Empty) { + Histogram h; + Validate(h); +} + +TEST(Histogram, SingleValue) { + Histogram h; + h.Add(-3.0); + Validate(h); +} + +TEST(Histogram, CustomBuckets) { + Histogram h({-10, -5, 0, 5, 10, 100, 1000, 10000, DBL_MAX}); + h.Add(-3.0); + h.Add(4.99); + h.Add(5.0); + h.Add(1000.0); + Validate(h); +} + +TEST(Histogram, Percentile) { + Histogram h({0, 10, 100, DBL_MAX}); + h.Add(-2); + h.Add(-2); + h.Add(0); + double median = h.Percentile(50.0); + EXPECT_EQ(median, -0.5); +} + +TEST(Histogram, Basic) { + Histogram h; + for (int i = 0; i < 100; i++) { + h.Add(i); + } + for (int i = 1000; i < 100000; i += 1000) { + h.Add(i); + } + Validate(h); +} + +TEST(ThreadSafeHistogram, Basic) { + // Fill a normal histogram. 
+ Histogram h; + for (int i = 0; i < 100; i++) { + h.Add(i); + } + + // Fill a thread-safe histogram with the same values. + ThreadSafeHistogram tsh; + for (int i = 0; i < 100; i++) { + tsh.Add(i); + } + + for (int i = 0; i < 2; ++i) { + bool preserve_zero_buckets = (i == 0); + HistogramProto h_proto; + h.EncodeToProto(&h_proto, preserve_zero_buckets); + HistogramProto tsh_proto; + tsh.EncodeToProto(&tsh_proto, preserve_zero_buckets); + + // Let's decode from the proto of the other histogram type. + Histogram h2; + EXPECT_TRUE(h2.DecodeFromProto(tsh_proto)); + ThreadSafeHistogram tsh2; + EXPECT_TRUE(tsh2.DecodeFromProto(h_proto)); + + // Now let's reencode and check they match. + EXPECT_EQ(h2.ToString(), tsh2.ToString()); + } + + EXPECT_EQ(h.Median(), tsh.Median()); + EXPECT_EQ(h.Percentile(40.0), tsh.Percentile(40.0)); + EXPECT_EQ(h.Average(), tsh.Average()); + EXPECT_EQ(h.StandardDeviation(), tsh.StandardDeviation()); + EXPECT_EQ(h.ToString(), tsh.ToString()); +} + +} // namespace histogram +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc new file mode 100644 index 0000000000..1ddaa2eb78 --- /dev/null +++ b/tensorflow/core/lib/io/block.cc @@ -0,0 +1,236 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// Decodes the blocks generated by block_builder.cc. + +#include "tensorflow/core/lib/io/block.h" + +#include +#include +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +inline uint32 Block::NumRestarts() const { + assert(size_ >= sizeof(uint32)); + return core::DecodeFixed32(data_ + size_ - sizeof(uint32)); +} + +Block::Block(const BlockContents& contents) + : data_(contents.data.data()), + size_(contents.data.size()), + owned_(contents.heap_allocated) { + if (size_ < sizeof(uint32)) { + size_ = 0; // Error marker + } else { + size_t max_restarts_allowed = (size_ - sizeof(uint32)) / sizeof(uint32); + if (NumRestarts() > max_restarts_allowed) { + // The size is too small for NumRestarts() + size_ = 0; + } else { + restart_offset_ = size_ - (1 + NumRestarts()) * sizeof(uint32); + } + } +} + +Block::~Block() { + if (owned_) { + delete[] data_; + } +} + +// Helper routine: decode the next block entry starting at "p", +// storing the number of shared key bytes, non_shared key bytes, +// and the length of the value in "*shared", "*non_shared", and +// "*value_length", respectively. Will not dereference past "limit". +// +// If any errors are detected, returns NULL. Otherwise, returns a +// pointer to the key delta (just past the three decoded values). 
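// For example, the byte sequence 02 03 05 "cde" "vvvvv" decodes as shared=2,
// non_shared=3, value_length=5: the entry's key reuses the first two bytes of
// the previous key and appends "cde", and the following five bytes are the
// value.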
+static inline const char* DecodeEntry(const char* p, const char* limit, + uint32* shared, uint32* non_shared, + uint32* value_length) { + if (limit - p < 3) return NULL; + *shared = reinterpret_cast(p)[0]; + *non_shared = reinterpret_cast(p)[1]; + *value_length = reinterpret_cast(p)[2]; + if ((*shared | *non_shared | *value_length) < 128) { + // Fast path: all three values are encoded in one byte each + p += 3; + } else { + if ((p = core::GetVarint32Ptr(p, limit, shared)) == NULL) return NULL; + if ((p = core::GetVarint32Ptr(p, limit, non_shared)) == NULL) return NULL; + if ((p = core::GetVarint32Ptr(p, limit, value_length)) == NULL) return NULL; + } + + if (static_cast(limit - p) < (*non_shared + *value_length)) { + return NULL; + } + return p; +} + +class Block::Iter : public Iterator { + private: + const char* const data_; // underlying block contents + uint32 const restarts_; // Offset of restart array (list of fixed32) + uint32 const num_restarts_; // Number of uint32 entries in restart array + + // current_ is offset in data_ of current entry. >= restarts_ if !Valid + uint32 current_; + uint32 restart_index_; // Index of restart block in which current_ falls + string key_; + StringPiece value_; + Status status_; + + inline int Compare(const StringPiece& a, const StringPiece& b) const { + return a.compare(b); + } + + // Return the offset in data_ just past the end of the current entry. + inline uint32 NextEntryOffset() const { + return (value_.data() + value_.size()) - data_; + } + + uint32 GetRestartPoint(uint32 index) { + assert(index < num_restarts_); + return core::DecodeFixed32(data_ + restarts_ + index * sizeof(uint32)); + } + + void SeekToRestartPoint(uint32 index) { + key_.clear(); + restart_index_ = index; + // current_ will be fixed by ParseNextKey(); + + // ParseNextKey() starts at the end of value_, so set value_ accordingly + uint32 offset = GetRestartPoint(index); + value_ = StringPiece(data_ + offset, 0); + } + + public: + Iter(const char* data, uint32 restarts, uint32 num_restarts) + : data_(data), + restarts_(restarts), + num_restarts_(num_restarts), + current_(restarts_), + restart_index_(num_restarts_) { + assert(num_restarts_ > 0); + } + + virtual bool Valid() const { return current_ < restarts_; } + virtual Status status() const { return status_; } + virtual StringPiece key() const { + assert(Valid()); + return key_; + } + virtual StringPiece value() const { + assert(Valid()); + return value_; + } + + virtual void Next() { + assert(Valid()); + ParseNextKey(); + } + + virtual void Seek(const StringPiece& target) { + // Binary search in restart array to find the last restart point + // with a key < target + uint32 left = 0; + uint32 right = num_restarts_ - 1; + while (left < right) { + uint32 mid = (left + right + 1) / 2; + uint32 region_offset = GetRestartPoint(mid); + uint32 shared, non_shared, value_length; + const char* key_ptr = + DecodeEntry(data_ + region_offset, data_ + restarts_, &shared, + &non_shared, &value_length); + if (key_ptr == NULL || (shared != 0)) { + CorruptionError(); + return; + } + StringPiece mid_key(key_ptr, non_shared); + if (Compare(mid_key, target) < 0) { + // Key at "mid" is smaller than "target". Therefore all + // blocks before "mid" are uninteresting. + left = mid; + } else { + // Key at "mid" is >= "target". Therefore all blocks at or + // after "mid" are uninteresting. 
+ right = mid - 1; + } + } + + // Linear search (within restart block) for first key >= target + SeekToRestartPoint(left); + while (true) { + if (!ParseNextKey()) { + return; + } + if (Compare(key_, target) >= 0) { + return; + } + } + } + + virtual void SeekToFirst() { + SeekToRestartPoint(0); + ParseNextKey(); + } + + private: + void CorruptionError() { + current_ = restarts_; + restart_index_ = num_restarts_; + status_ = errors::DataLoss("bad entry in block"); + key_.clear(); + value_.clear(); + } + + bool ParseNextKey() { + current_ = NextEntryOffset(); + const char* p = data_ + current_; + const char* limit = data_ + restarts_; // Restarts come right after data + if (p >= limit) { + // No more entries to return. Mark as invalid. + current_ = restarts_; + restart_index_ = num_restarts_; + return false; + } + + // Decode next entry + uint32 shared, non_shared, value_length; + p = DecodeEntry(p, limit, &shared, &non_shared, &value_length); + if (p == NULL || key_.size() < shared) { + CorruptionError(); + return false; + } else { + key_.resize(shared); + key_.append(p, non_shared); + value_ = StringPiece(p + non_shared, value_length); + while (restart_index_ + 1 < num_restarts_ && + GetRestartPoint(restart_index_ + 1) < current_) { + ++restart_index_; + } + return true; + } + } +}; + +Iterator* Block::NewIterator() { + if (size_ < sizeof(uint32)) { + return NewErrorIterator(errors::DataLoss("bad block contents")); + } + const uint32 num_restarts = NumRestarts(); + if (num_restarts == 0) { + return NewEmptyIterator(); + } else { + return new Iter(data_, restart_offset_, num_restarts); + } +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block.h b/tensorflow/core/lib/io/block.h new file mode 100644 index 0000000000..bf53245b8d --- /dev/null +++ b/tensorflow/core/lib/io/block.h @@ -0,0 +1,45 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_BLOCK_H_ +#define TENSORFLOW_LIB_IO_BLOCK_H_ + +#include +#include +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +struct BlockContents; + +class Block { + public: + // Initialize the block with the specified contents. + explicit Block(const BlockContents& contents); + + ~Block(); + + size_t size() const { return size_; } + Iterator* NewIterator(); + + private: + uint32 NumRestarts() const; + + const char* data_; + size_t size_; + uint32 restart_offset_; // Offset in data_ of restart array + bool owned_; // Block owns data_[] + + // No copying allowed + Block(const Block&); + void operator=(const Block&); + + class Iter; +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_BLOCK_H_ diff --git a/tensorflow/core/lib/io/block_builder.cc b/tensorflow/core/lib/io/block_builder.cc new file mode 100644 index 0000000000..d94048d744 --- /dev/null +++ b/tensorflow/core/lib/io/block_builder.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// BlockBuilder generates blocks where keys are prefix-compressed: +// +// When we store a key, we drop the prefix shared with the previous +// string. This helps reduce the space requirement significantly. 
+// Furthermore, once every K keys, we do not apply the prefix +// compression and store the entire key. We call this a "restart +// point". The tail end of the block stores the offsets of all of the +// restart points, and can be used to do a binary search when looking +// for a particular key. Values are stored as-is (without compression) +// immediately following the corresponding key. +// +// An entry for a particular key-value pair has the form: +// shared_bytes: varint32 +// unshared_bytes: varint32 +// value_length: varint32 +// key_delta: char[unshared_bytes] +// value: char[value_length] +// shared_bytes == 0 for restart points. +// +// The trailer of the block has the form: +// restarts: uint32[num_restarts] +// num_restarts: uint32 +// restarts[i] contains the offset within the block of the ith restart point. + +#include "tensorflow/core/lib/io/block_builder.h" + +#include +#include +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/core/coding.h" + +namespace tensorflow { +namespace table { + +BlockBuilder::BlockBuilder(const Options* options) + : options_(options), restarts_(), counter_(0), finished_(false) { + assert(options->block_restart_interval >= 1); + restarts_.push_back(0); // First restart point is at offset 0 +} + +void BlockBuilder::Reset() { + buffer_.clear(); + restarts_.clear(); + restarts_.push_back(0); // First restart point is at offset 0 + counter_ = 0; + finished_ = false; + last_key_.clear(); +} + +size_t BlockBuilder::CurrentSizeEstimate() const { + return (buffer_.size() + // Raw data buffer + restarts_.size() * sizeof(uint32) + // Restart array + sizeof(uint32)); // Restart array length +} + +StringPiece BlockBuilder::Finish() { + // Append restart array + for (size_t i = 0; i < restarts_.size(); i++) { + core::PutFixed32(&buffer_, restarts_[i]); + } + core::PutFixed32(&buffer_, restarts_.size()); + finished_ = true; + return StringPiece(buffer_); +} + +void BlockBuilder::Add(const StringPiece& key, const StringPiece& value) { + StringPiece last_key_piece(last_key_); + assert(!finished_); + assert(counter_ <= options_->block_restart_interval); + assert(buffer_.empty() // No values yet? + || key.compare(last_key_piece) > 0); + size_t shared = 0; + if (counter_ < options_->block_restart_interval) { + // See how much sharing to do with previous string + const size_t min_length = std::min(last_key_piece.size(), key.size()); + while ((shared < min_length) && (last_key_piece[shared] == key[shared])) { + shared++; + } + } else { + // Restart compression + restarts_.push_back(buffer_.size()); + counter_ = 0; + } + const size_t non_shared = key.size() - shared; + + // Add "" to buffer_ + core::PutVarint32(&buffer_, shared); + core::PutVarint32(&buffer_, non_shared); + core::PutVarint32(&buffer_, value.size()); + + // Add string delta to buffer_ followed by value + buffer_.append(key.data() + shared, non_shared); + buffer_.append(value.data(), value.size()); + + // Update state + last_key_.resize(shared); + last_key_.append(key.data() + shared, non_shared); + assert(StringPiece(last_key_) == key); + counter_++; +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h new file mode 100644 index 0000000000..e07a647805 --- /dev/null +++ b/tensorflow/core/lib/io/block_builder.h @@ -0,0 +1,57 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ +#define TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ + +#include + +#include +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace table { + +struct Options; + +class BlockBuilder { + public: + explicit BlockBuilder(const Options* options); + + // Reset the contents as if the BlockBuilder was just constructed. + void Reset(); + + // REQUIRES: Finish() has not been called since the last call to Reset(). + // REQUIRES: key is larger than any previously added key + void Add(const StringPiece& key, const StringPiece& value); + + // Finish building the block and return a slice that refers to the + // block contents. The returned slice will remain valid for the + // lifetime of this builder or until Reset() is called. + StringPiece Finish(); + + // Returns an estimate of the current (uncompressed) size of the block + // we are building. + size_t CurrentSizeEstimate() const; + + // Return true iff no entries have been added since the last Reset() + bool empty() const { return buffer_.empty(); } + + private: + const Options* options_; + string buffer_; // Destination buffer + std::vector restarts_; // Restart points + int counter_; // Number of entries emitted since restart + bool finished_; // Has Finish() been called? + string last_key_; + + // No copying allowed + BlockBuilder(const BlockBuilder&); + void operator=(const BlockBuilder&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ diff --git a/tensorflow/core/lib/io/format.cc b/tensorflow/core/lib/io/format.cc new file mode 100644 index 0000000000..259cfc13dc --- /dev/null +++ b/tensorflow/core/lib/io/format.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
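A minimal sketch of how the BlockBuilder declared above is driven (hypothetical
keys and values; `options` stands for a table Options object whose
block_restart_interval is larger than the number of keys shown):

//   BlockBuilder b(&options);
//   b.Add("apple", "red");           // emits shared=0, non_shared=5,
//                                    //       value_length=3, "apple", "red"
//   b.Add("applesauce", "brown");    // emits shared=5, non_shared=5,
//                                    //       value_length=5, "sauce", "brown"
//   StringPiece block = b.Finish();  // appends the restart offsets and count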
+ +#include "tensorflow/core/lib/io/format.h" + +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +void BlockHandle::EncodeTo(string* dst) const { + // Sanity check that all fields have been set + assert(offset_ != ~static_cast(0)); + assert(size_ != ~static_cast(0)); + core::PutVarint64(dst, offset_); + core::PutVarint64(dst, size_); +} + +Status BlockHandle::DecodeFrom(StringPiece* input) { + if (core::GetVarint64(input, &offset_) && core::GetVarint64(input, &size_)) { + return Status::OK(); + } else { + return errors::DataLoss("bad block handle"); + } +} + +void Footer::EncodeTo(string* dst) const { +#ifndef NDEBUG + const size_t original_size = dst->size(); +#endif + metaindex_handle_.EncodeTo(dst); + index_handle_.EncodeTo(dst); + dst->resize(2 * BlockHandle::kMaxEncodedLength); // Padding + core::PutFixed32(dst, static_cast(kTableMagicNumber & 0xffffffffu)); + core::PutFixed32(dst, static_cast(kTableMagicNumber >> 32)); + assert(dst->size() == original_size + kEncodedLength); +} + +Status Footer::DecodeFrom(StringPiece* input) { + const char* magic_ptr = input->data() + kEncodedLength - 8; + const uint32 magic_lo = core::DecodeFixed32(magic_ptr); + const uint32 magic_hi = core::DecodeFixed32(magic_ptr + 4); + const uint64 magic = + ((static_cast(magic_hi) << 32) | (static_cast(magic_lo))); + if (magic != kTableMagicNumber) { + return errors::DataLoss("not an sstable (bad magic number)"); + } + + Status result = metaindex_handle_.DecodeFrom(input); + if (result.ok()) { + result = index_handle_.DecodeFrom(input); + } + if (result.ok()) { + // We skip over any leftover data (just padding for now) in "input" + const char* end = magic_ptr + 8; + *input = StringPiece(end, input->data() + input->size() - end); + } + return result; +} + +Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result) { + result->data = StringPiece(); + result->cachable = false; + result->heap_allocated = false; + + // Read the block contents as well as the type/crc footer. + // See table_builder.cc for the code that built this structure. + size_t n = static_cast(handle.size()); + char* buf = new char[n + kBlockTrailerSize]; + StringPiece contents; + Status s = + file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf); + if (!s.ok()) { + delete[] buf; + return s; + } + if (contents.size() != n + kBlockTrailerSize) { + delete[] buf; + return errors::DataLoss("truncated block read"); + } + + // Check the crc of the type and the block contents + const char* data = contents.data(); // Pointer to where Read put the data + // This checksum verification is optional. We leave it on for now + const bool verify_checksum = true; + if (verify_checksum) { + const uint32 crc = crc32c::Unmask(core::DecodeFixed32(data + n + 1)); + const uint32 actual = crc32c::Value(data, n + 1); + if (actual != crc) { + delete[] buf; + s = errors::DataLoss("block checksum mismatch"); + return s; + } + } + + switch (data[n]) { + case kNoCompression: + if (data != buf) { + // File implementation gave us pointer to some other data. + // Use it directly under the assumption that it will be live + // while the file is open. 
+ delete[] buf; + result->data = StringPiece(data, n); + result->heap_allocated = false; + result->cachable = false; // Do not double-cache + } else { + result->data = StringPiece(buf, n); + result->heap_allocated = true; + result->cachable = true; + } + + // Ok + break; + case kSnappyCompression: { + size_t ulength = 0; + if (!port::Snappy_GetUncompressedLength(data, n, &ulength)) { + delete[] buf; + return errors::DataLoss("corrupted compressed block contents"); + } + char* ubuf = new char[ulength]; + if (!port::Snappy_Uncompress(data, n, ubuf)) { + delete[] buf; + delete[] ubuf; + return errors::DataLoss("corrupted compressed block contents"); + } + delete[] buf; + result->data = StringPiece(ubuf, ulength); + result->heap_allocated = true; + result->cachable = true; + break; + } + default: + delete[] buf; + return errors::DataLoss("bad block type"); + } + + return Status::OK(); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/format.h b/tensorflow/core/lib/io/format.h new file mode 100644 index 0000000000..3121c41bb8 --- /dev/null +++ b/tensorflow/core/lib/io/format.h @@ -0,0 +1,99 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_FORMAT_H_ +#define TENSORFLOW_LIB_IO_FORMAT_H_ + +#include +#include +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/io/table_builder.h" + +namespace tensorflow { +class RandomAccessFile; +namespace table { + +class Block; + +// BlockHandle is a pointer to the extent of a file that stores a data +// block or a meta block. +class BlockHandle { + public: + BlockHandle(); + + // The offset of the block in the file. + uint64 offset() const { return offset_; } + void set_offset(uint64 offset) { offset_ = offset; } + + // The size of the stored block + uint64 size() const { return size_; } + void set_size(uint64 size) { size_ = size; } + + void EncodeTo(string* dst) const; + Status DecodeFrom(StringPiece* input); + + // Maximum encoding length of a BlockHandle + enum { kMaxEncodedLength = 10 + 10 }; + + private: + uint64 offset_; + uint64 size_; +}; + +// Footer encapsulates the fixed information stored at the tail +// end of every table file. +class Footer { + public: + Footer() {} + + // The block handle for the metaindex block of the table + const BlockHandle& metaindex_handle() const { return metaindex_handle_; } + void set_metaindex_handle(const BlockHandle& h) { metaindex_handle_ = h; } + + // The block handle for the index block of the table + const BlockHandle& index_handle() const { return index_handle_; } + void set_index_handle(const BlockHandle& h) { index_handle_ = h; } + + void EncodeTo(string* dst) const; + Status DecodeFrom(StringPiece* input); + + // Encoded length of a Footer. Note that the serialization of a + // Footer will always occupy exactly this many bytes. It consists + // of two block handles and a magic number. + enum { kEncodedLength = 2 * BlockHandle::kMaxEncodedLength + 8 }; + + private: + BlockHandle metaindex_handle_; + BlockHandle index_handle_; +}; + +// kTableMagicNumber was picked by running +// echo http://code.google.com/p/leveldb/ | sha1sum +// and taking the leading 64 bits. 
+static const uint64 kTableMagicNumber = 0xdb4775248b80fb57ull; + +// 1-byte type + 32-bit crc +static const size_t kBlockTrailerSize = 5; + +struct BlockContents { + StringPiece data; // Actual contents of data + bool cachable; // True iff data can be cached + bool heap_allocated; // True iff caller should delete[] data.data() +}; + +// Read the block identified by "handle" from "file". On failure +// return non-OK. On success fill *result and return OK. +extern Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result); + +// Implementation details follow. Clients should ignore, + +inline BlockHandle::BlockHandle() + : offset_(~static_cast(0)), size_(~static_cast(0)) {} + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_FORMAT_H_ diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc new file mode 100644 index 0000000000..8fa245a546 --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer.cc @@ -0,0 +1,112 @@ +#include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace io { + +InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes) + : file_(file), + file_pos_(0), + size_(buffer_bytes), + buf_(new char[size_]), + pos_(buf_), + limit_(buf_) {} + +InputBuffer::~InputBuffer() { + delete file_; + delete[] buf_; +} + +Status InputBuffer::FillBuffer() { + StringPiece data; + Status s = file_->Read(file_pos_, size_, &data, buf_); + if (data.data() != buf_) { + memmove(buf_, data.data(), data.size()); + } + pos_ = buf_; + limit_ = pos_ + data.size(); + file_pos_ += data.size(); + return s; +} + +Status InputBuffer::ReadLine(string* result) { + result->clear(); + int i; + Status s; + for (i = 0;; i++) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + char c = *pos_++; + if (c == '\n') { + // We don't append the '\n' to *result + return Status::OK(); + } + *result += c; + } + if (errors::IsOutOfRange(s) && !result->empty()) { + return Status::OK(); + } + return s; +} + +Status InputBuffer::ReadNBytes(int64 bytes_to_read, string* result) { + result->clear(); + if (bytes_to_read < 0) { + return errors::InvalidArgument("Can't read a negative number of bytes: ", + bytes_to_read); + } + result->reserve(bytes_to_read); + Status s; + while (result->size() < static_cast(bytes_to_read)) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + const int64 bytes_to_copy = + std::min(limit_ - pos_, bytes_to_read - result->size()); + result->insert(result->size(), pos_, bytes_to_copy); + pos_ += bytes_to_copy; + } + if (errors::IsOutOfRange(s) && + (result->size() == static_cast(bytes_to_read))) { + return Status::OK(); + } + return s; +} + +Status InputBuffer::SkipNBytes(int64 bytes_to_skip) { + if (bytes_to_skip < 0) { + return errors::InvalidArgument("Can only skip forward, not ", + bytes_to_skip); + } + int64 bytes_skipped = 0; + Status s; + while (bytes_skipped < bytes_to_skip) { + if (pos_ == limit_) { + // Get more data into buffer + s = FillBuffer(); + if (limit_ == buf_) { + break; + } + } + const int64 bytes_to_advance = + std::min(limit_ - pos_, bytes_to_skip - bytes_skipped); + bytes_skipped += bytes_to_advance; + pos_ += bytes_to_advance; + } + if (errors::IsOutOfRange(s) && bytes_skipped == bytes_to_skip) { + return Status::OK(); + } + return s; +} + +} // namespace io +} // namespace 
tensorflow diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h new file mode 100644 index 0000000000..6879f30567 --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer.h @@ -0,0 +1,62 @@ +#ifndef TENSORFLOW_LIB_IO_INPUTBUFFER_H_ +#define TENSORFLOW_LIB_IO_INPUTBUFFER_H_ + +#include +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace io { + +// An InputBuffer provides a buffer on top of a RandomAccessFile. +// A given instance of an InputBuffer is NOT safe for concurrent use +// by multiple threads +class InputBuffer { + public: + // Create an InputBuffer for "file" with a buffer size of + // "buffer_bytes" bytes. Takes ownership of "file" and will + // delete it when the InputBuffer is destroyed. + InputBuffer(RandomAccessFile* file, size_t buffer_bytes); + ~InputBuffer(); + + // Read one text line of data into "*result" until end-of-file or a + // \n is read. (The \n is not included in the result.) Overwrites + // any existing data in *result. + // + // If successful, returns OK. If we are already at the end of the + // file, we return an OUT_OF_RANGE error. Otherwise, we return + // some other non-OK status. + Status ReadLine(string* result); + + // Reads bytes_to_read bytes into *result, overwriting *result. + // + // If successful, returns OK. If we there are not enough bytes to + // read before the end of the file, we return an OUT_OF_RANGE error. + // Otherwise, we return some other non-OK status. + Status ReadNBytes(int64 bytes_to_read, string* result); + + // Like ReadNBytes() without returning the bytes read. + Status SkipNBytes(int64 bytes_to_skip); + + // Returns the position in the file. + int64 Tell() const { return file_pos_ - (limit_ - pos_); } + + private: + Status FillBuffer(); + + RandomAccessFile* file_; // Owned + int64 file_pos_; // Next position to read from in "file_" + size_t size_; // Size of "buf_" + char* buf_; // The buffer itself + // [pos_,limit_) hold the "limit_ - pos_" bytes just before "file_pos_" + char* pos_; // Current position in "buf" + char* limit_; // Just past end of valid data in "buf" + + TF_DISALLOW_COPY_AND_ASSIGN(InputBuffer); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_INPUTBUFFER_H_ diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc new file mode 100644 index 0000000000..34094f018c --- /dev/null +++ b/tensorflow/core/lib/io/inputbuffer_test.cc @@ -0,0 +1,174 @@ +#include "tensorflow/core/lib/io/inputbuffer.h" + +#include "tensorflow/core/public/env.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +static std::vector BufferSizes() { + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 65536}; +} + +TEST(InputBuffer, ReadLine_Empty) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, ""); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + 
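As a complement to these tests, a minimal sketch of typical InputBuffer use
outside a test (a hypothetical helper; the file name and buffer size are
arbitrary, and end-of-file surfaces as the OUT_OF_RANGE status the tests check
for):

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/env.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

// Counts the lines of a text file by reading it through an InputBuffer.
int64 CountLines(const string& fname) {
  RandomAccessFile* file;
  TF_CHECK_OK(Env::Default()->NewRandomAccessFile(fname, &file));
  io::InputBuffer in(file, 64 << 10);  // InputBuffer takes ownership of file.
  int64 lines = 0;
  string line;
  Status s;
  while ((s = in.ReadLine(&line)).ok()) {
    ++lines;
  }
  CHECK(errors::IsOutOfRange(s)) << s;  // Anything else is a real error.
  return lines;
}

}  // namespace tensorflow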
+TEST(InputBuffer, ReadLine1) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\nline two\nline three\n"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadLine_NoTrailingNewLine) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\nline two\nline three"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadLine_EmptyLines) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "line one\n\n\nline two\nline three"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string line; + io::InputBuffer in(file, buf_size); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line one"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, ""); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, ""); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line two"); + TF_CHECK_OK(in.ReadLine(&line)); + EXPECT_EQ(line, "line three"); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + // A second call should also return end of file + EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line))); + } +} + +TEST(InputBuffer, ReadNBytes) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + WriteStringToFile(env, fname, "0123456789"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string read; + io::InputBuffer in(file, buf_size); + EXPECT_EQ(0, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(3, &read)); + EXPECT_EQ(read, "012"); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(4, &read)); + EXPECT_EQ(read, "3456"); + EXPECT_EQ(7, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(7, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, "789"); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + } +} + +TEST(InputBuffer, SkipNBytes) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/inputbuffer_test"; + 
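  // (Editorial comment, not in the original source.) The test below writes a
  // fixed 10-byte file and, for every buffer size in BufferSizes(), interleaves
  // SkipNBytes() with ReadNBytes() while checking Tell(), so buffer refills are
  // exercised at many different alignments.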
WriteStringToFile(env, fname, "0123456789"); + + for (auto buf_size : BufferSizes()) { + RandomAccessFile* file; + TF_CHECK_OK(env->NewRandomAccessFile(fname, &file)); + string read; + io::InputBuffer in(file, buf_size); + EXPECT_EQ(0, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(3)); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(0)); + EXPECT_EQ(3, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(2, &read)); + EXPECT_EQ(read, "34"); + EXPECT_EQ(5, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(0)); + EXPECT_EQ(5, in.Tell()); + TF_CHECK_OK(in.SkipNBytes(2)); + EXPECT_EQ(7, in.Tell()); + TF_CHECK_OK(in.ReadNBytes(1, &read)); + EXPECT_EQ(read, "7"); + EXPECT_EQ(8, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5))); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5))); + EXPECT_EQ(10, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read))); + EXPECT_EQ(read, ""); + EXPECT_EQ(10, in.Tell()); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/iterator.cc b/tensorflow/core/lib/io/iterator.cc new file mode 100644 index 0000000000..878e93a911 --- /dev/null +++ b/tensorflow/core/lib/io/iterator.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +Iterator::Iterator() { + cleanup_.function = NULL; + cleanup_.next = NULL; +} + +Iterator::~Iterator() { + if (cleanup_.function != NULL) { + (*cleanup_.function)(cleanup_.arg1, cleanup_.arg2); + for (Cleanup* c = cleanup_.next; c != NULL;) { + (*c->function)(c->arg1, c->arg2); + Cleanup* next = c->next; + delete c; + c = next; + } + } +} + +void Iterator::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) { + assert(func != NULL); + Cleanup* c; + if (cleanup_.function == NULL) { + c = &cleanup_; + } else { + c = new Cleanup; + c->next = cleanup_.next; + cleanup_.next = c; + } + c->function = func; + c->arg1 = arg1; + c->arg2 = arg2; +} + +namespace { +class EmptyIterator : public Iterator { + public: + EmptyIterator(const Status& s) : status_(s) {} + virtual bool Valid() const { return false; } + virtual void Seek(const StringPiece& target) {} + virtual void SeekToFirst() {} + virtual void Next() { assert(false); } + StringPiece key() const { + assert(false); + return StringPiece(); + } + StringPiece value() const { + assert(false); + return StringPiece(); + } + virtual Status status() const { return status_; } + + private: + Status status_; +}; +} // namespace + +Iterator* NewEmptyIterator() { return new EmptyIterator(Status::OK()); } + +Iterator* NewErrorIterator(const Status& status) { + return new EmptyIterator(status); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/iterator.h b/tensorflow/core/lib/io/iterator.h new file mode 100644 index 0000000000..603a2f95fe --- /dev/null +++ b/tensorflow/core/lib/io/iterator.h @@ -0,0 +1,93 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// An iterator yields a sequence of key/value pairs from a source. +// The following class defines the interface. Multiple implementations +// are provided by this library. 
In particular, iterators are provided +// to access the contents of a Table or a DB. +// +// Multiple threads can invoke const methods on an Iterator without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same Iterator must use +// external synchronization. + +#ifndef TENSORFLOW_LIB_IO_ITERATOR_H_ +#define TENSORFLOW_LIB_IO_ITERATOR_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace table { + +class Iterator { + public: + Iterator(); + virtual ~Iterator(); + + // An iterator is either positioned at a key/value pair, or + // not valid. This method returns true iff the iterator is valid. + virtual bool Valid() const = 0; + + // Position at the first key in the source. The iterator is Valid() + // after this call iff the source is not empty. + virtual void SeekToFirst() = 0; + + // Position at the first key in the source that is at or past target. + // The iterator is Valid() after this call iff the source contains + // an entry that comes at or past target. + virtual void Seek(const StringPiece& target) = 0; + + // Moves to the next entry in the source. After this call, Valid() is + // true iff the iterator was not positioned at the last entry in the source. + // REQUIRES: Valid() + virtual void Next() = 0; + + // Return the key for the current entry. The underlying storage for + // the returned slice is valid only until the next modification of + // the iterator. + // REQUIRES: Valid() + virtual StringPiece key() const = 0; + + // Return the value for the current entry. The underlying storage for + // the returned slice is valid only until the next modification of + // the iterator. + // REQUIRES: Valid() + virtual StringPiece value() const = 0; + + // If an error has occurred, return it. Else return an ok status. + virtual Status status() const = 0; + + // Clients are allowed to register function/arg1/arg2 triples that + // will be invoked when this iterator is destroyed. + // + // Note that unlike all of the preceding methods, this method is + // not abstract and therefore clients should not override it. + typedef void (*CleanupFunction)(void* arg1, void* arg2); + void RegisterCleanup(CleanupFunction function, void* arg1, void* arg2); + + private: + struct Cleanup { + CleanupFunction function; + void* arg1; + void* arg2; + Cleanup* next; + }; + Cleanup cleanup_; + + // No copying allowed + Iterator(const Iterator&); + void operator=(const Iterator&); +}; + +// Return an empty iterator (yields nothing). +extern Iterator* NewEmptyIterator(); + +// Return an empty iterator with the specified status. 
+extern Iterator* NewErrorIterator(const Status& status); + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_ITERATOR_H_ diff --git a/tensorflow/core/lib/io/match.cc b/tensorflow/core/lib/io/match.cc new file mode 100644 index 0000000000..1563642d0b --- /dev/null +++ b/tensorflow/core/lib/io/match.cc @@ -0,0 +1,31 @@ +#include "tensorflow/core/lib/io/match.h" +#include +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace io { + +Status GetMatchingFiles(Env* env, const string& pattern, + std::vector* results) { + results->clear(); + std::vector all_files; + string dir = Dirname(pattern).ToString(); + if (dir.empty()) dir = "."; + string basename_pattern = Basename(pattern).ToString(); + Status s = env->GetChildren(dir, &all_files); + if (!s.ok()) { + return s; + } + for (const auto& f : all_files) { + int flags = 0; + if (fnmatch(basename_pattern.c_str(), Basename(f).ToString().c_str(), + flags) == 0) { + results->push_back(JoinPath(dir, f)); + } + } + return Status::OK(); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/match.h b/tensorflow/core/lib/io/match.h new file mode 100644 index 0000000000..fd194178e7 --- /dev/null +++ b/tensorflow/core/lib/io/match.h @@ -0,0 +1,24 @@ +#ifndef TENSORFLOW_LIB_IO_MATCH_H_ +#define TENSORFLOW_LIB_IO_MATCH_H_ + +#include +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +class Env; +namespace io { + +// Given a pattern, return the set of files that match the pattern. +// Note that this routine only supports wildcard characters in the +// basename portion of the pattern, not in the directory portion. If +// successful, return Status::OK and store the matching files in +// "*results". Otherwise, return a non-OK status. +Status GetMatchingFiles(Env* env, const string& pattern, + std::vector* results); + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_MATCH_H_ diff --git a/tensorflow/core/lib/io/match_test.cc b/tensorflow/core/lib/io/match_test.cc new file mode 100644 index 0000000000..aaa56e4e7e --- /dev/null +++ b/tensorflow/core/lib/io/match_test.cc @@ -0,0 +1,51 @@ +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/match.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" +#include + +namespace tensorflow { +namespace io { + +static string Match(Env* env, const string& suffix_pattern) { + std::vector results; + Status s = GetMatchingFiles(env, JoinPath(testing::TmpDir(), suffix_pattern), + &results); + if (!s.ok()) { + return s.ToString(); + } else { + string r; + std::sort(results.begin(), results.end()); + for (size_t i = 0; i < results.size(); i++) { + strings::StrAppend(&r, (i > 0) ? 
"," : "", Basename(results[i])); + } + return r; + } +} +TEST(GetMatchingFiles, Simple) { + Env* env = Env::Default(); + EXPECT_EQ(Match(env, "thereisnosuchfile"), ""); + EXPECT_EQ(Match(env, "thereisnosuchfile*"), ""); + + // Populate a few files + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-00"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-0a"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-01"), "")); + EXPECT_OK(WriteStringToFile(Env::Default(), + JoinPath(testing::TmpDir(), "match-aaa"), "")); + + EXPECT_EQ(Match(env, "match-*"), "match-00,match-01,match-0a,match-aaa"); + EXPECT_EQ(Match(env, "match-0[0-9]"), "match-00,match-01"); + EXPECT_EQ(Match(env, "match-?[0-9]"), "match-00,match-01"); + EXPECT_EQ(Match(env, "match-?a*"), "match-0a,match-aaa"); + EXPECT_EQ(Match(env, "match-??"), "match-00,match-01,match-0a"); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc new file mode 100644 index 0000000000..1359ded0f0 --- /dev/null +++ b/tensorflow/core/lib/io/path.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace io { + +string JoinPath(StringPiece part1, StringPiece part2) { + string result; + + StringPiece paths[2] = {part1, part2}; + for (StringPiece path : paths) { + if (path.empty()) continue; + + if (result.empty()) { + result = path.ToString(); + continue; + } + + if (result[result.size() - 1] == '/') { + if (IsAbsolutePath(path)) { + strings::StrAppend(&result, path.substr(1)); + } else { + strings::StrAppend(&result, path); + } + } else { + if (IsAbsolutePath(path)) { + strings::StrAppend(&result, path); + } else { + strings::StrAppend(&result, "/", path); + } + } + } + + return result; +} + +namespace internal { + +// Return the parts of the path, split on the final "/". If there is no +// "/" in the path, the first part of the output is empty and the second +// is the input. If the only "/" in the path is the first character, it is +// the first part of the output. +std::pair SplitPath(StringPiece path) { + auto pos = path.rfind('/'); + + // Handle the case with no '/' in 'path'. + if (pos == StringPiece::npos) + return std::make_pair(StringPiece(path.data(), 0), path); + + // Handle the case with a single leading '/' in 'path'. + if (pos == 0) + return std::make_pair(StringPiece(path.data(), 1), + StringPiece(path.data() + 1, path.size() - 1)); + + return std::make_pair( + StringPiece(path.data(), pos), + StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); +} + +// Return the parts of the basename of path, split on the final ".". +// If there is no "." in the basename or "." is the final character in the +// basename, the second value will be empty. 
+std::pair SplitBasename(StringPiece path) { + path = Basename(path); + + auto pos = path.rfind('.'); + if (pos == StringPiece::npos) + return std::make_pair(path, StringPiece(path.data() + path.size(), 0)); + return std::make_pair( + StringPiece(path.data(), pos), + StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); +} +} // namespace internal + +bool IsAbsolutePath(StringPiece path) { + return !path.empty() && path[0] == '/'; +} + +StringPiece Dirname(StringPiece path) { + return internal::SplitPath(path).first; +} + +StringPiece Basename(StringPiece path) { + return internal::SplitPath(path).second; +} + +StringPiece Extension(StringPiece path) { + return internal::SplitBasename(path).second; +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h new file mode 100644 index 0000000000..01483f1702 --- /dev/null +++ b/tensorflow/core/lib/io/path.h @@ -0,0 +1,47 @@ +#ifndef TENSORFLOW_LIB_IO_PATH_H_ +#define TENSORFLOW_LIB_IO_PATH_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +class StringPiece; +namespace io { + +// Utility routines for processing filenames + +// Join multiple paths together, without introducing unnecessary path +// separators. +// For example: +// +// Arguments | JoinPath +// ---------------------------+---------- +// '/foo', 'bar' | /foo/bar +// '/foo/', 'bar' | /foo/bar +// '/foo', '/bar' | /foo/bar +// +// Usage: +// string path = io::JoinPath("/mydir", filename); +// string path = io::JoinPath(FLAGS_test_srcdir, filename); +string JoinPath(StringPiece part1, StringPiece part2); + +// Return true if path is absolute. +bool IsAbsolutePath(StringPiece path); + +// Returns the part of the path before the final "/". If there is a single +// leading "/" in the path, the result will be the leading "/". If there is +// no "/" in the path, the result is the empty prefix of the input. +StringPiece Dirname(StringPiece path); + +// Returns the part of the path after the final "/". If there is no +// "/" in the path, the result is the same as the input. +StringPiece Basename(StringPiece path); + +// Returns the part of the basename of path after the final ".". If +// there is no "." in the basename, the result is empty. 
+StringPiece Extension(StringPiece path); + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_PATH_H_ diff --git a/tensorflow/core/lib/io/path_test.cc b/tensorflow/core/lib/io/path_test.cc new file mode 100644 index 0000000000..b670e44f1f --- /dev/null +++ b/tensorflow/core/lib/io/path_test.cc @@ -0,0 +1,65 @@ +#include "tensorflow/core/lib/io/path.h" +#include + +namespace tensorflow { +namespace io { + +TEST(PathTest, JoinPath) { + EXPECT_EQ("/foo/bar", JoinPath("/foo", "bar")); + EXPECT_EQ("foo/bar", JoinPath("foo", "bar")); + EXPECT_EQ("foo/bar", JoinPath("foo", "/bar")); + EXPECT_EQ("/foo/bar", JoinPath("/foo", "/bar")); + + EXPECT_EQ("/bar", JoinPath("", "/bar")); + EXPECT_EQ("bar", JoinPath("", "bar")); + EXPECT_EQ("/foo", JoinPath("/foo", "")); + + EXPECT_EQ("/foo/bar/baz/blah/blink/biz", + JoinPath("/foo/bar/baz/", "/blah/blink/biz")); +} + +TEST(PathTest, IsAbsolutePath) { + EXPECT_FALSE(IsAbsolutePath("")); + EXPECT_FALSE(IsAbsolutePath("../foo")); + EXPECT_FALSE(IsAbsolutePath("foo")); + EXPECT_FALSE(IsAbsolutePath("./foo")); + EXPECT_FALSE(IsAbsolutePath("foo/bar/baz/")); + EXPECT_TRUE(IsAbsolutePath("/foo")); + EXPECT_TRUE(IsAbsolutePath("/foo/bar/../baz")); +} + +TEST(PathTest, Dirname) { + EXPECT_EQ("/hello", Dirname("/hello/")); + EXPECT_EQ("/", Dirname("/hello")); + EXPECT_EQ("hello", Dirname("hello/world")); + EXPECT_EQ("hello", Dirname("hello/")); + EXPECT_EQ("", Dirname("world")); + EXPECT_EQ("/", Dirname("/")); + EXPECT_EQ("", Dirname("")); +} + +TEST(PathTest, Basename) { + EXPECT_EQ("", Basename("/hello/")); + EXPECT_EQ("hello", Basename("/hello")); + EXPECT_EQ("world", Basename("hello/world")); + EXPECT_EQ("", Basename("hello/")); + EXPECT_EQ("world", Basename("world")); + EXPECT_EQ("", Basename("/")); + EXPECT_EQ("", Basename("")); +} + +TEST(PathTest, Extension) { + EXPECT_EQ("gif", Extension("foo.gif")); + EXPECT_EQ("", Extension("foo.")); + EXPECT_EQ("", Extension("")); + EXPECT_EQ("", Extension("/")); + EXPECT_EQ("", Extension("foo")); + EXPECT_EQ("", Extension("foo/")); + EXPECT_EQ("gif", Extension("/a/path/to/foo.gif")); + EXPECT_EQ("html", Extension("/a/path.bar/to/foo.html")); + EXPECT_EQ("", Extension("/a/path.bar/to/foo")); + EXPECT_EQ("baz", Extension("/a/path.bar/to/foo.bar.baz")); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc new file mode 100644 index 0000000000..2f0fabff63 --- /dev/null +++ b/tensorflow/core/lib/io/record_reader.cc @@ -0,0 +1,80 @@ +#include "tensorflow/core/lib/io/record_reader.h" + +#include +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace io { + +RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {} + +RecordReader::~RecordReader() {} + +// Read n+4 bytes from file, verify that checksum of first n bytes is +// stored in the last 4 bytes and store the first n bytes in *result. +// May use *storage as backing store. 
+static Status ReadChecksummed(RandomAccessFile* file, uint64 offset, + size_t n, StringPiece* result, + string* storage) { + if (n >= SIZE_MAX - sizeof(uint32)) { + return errors::DataLoss("record size too large"); + } + + const size_t expected = n + sizeof(uint32); + storage->resize(expected); + StringPiece data; + Status s = file->Read(offset, expected, &data, &(*storage)[0]); + if (!s.ok()) { + return s; + } + if (data.size() != expected) { + if (data.size() == 0) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } + } + uint32 masked_crc = core::DecodeFixed32(data.data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(data.data(), n); + return Status::OK(); +} + +Status RecordReader::ReadRecord(uint64* offset, string* record) { + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + + // Read length + StringPiece lbuf; + Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record); + if (!s.ok()) { + return s; + } + const uint64 length = core::DecodeFixed64(lbuf.data()); + + // Read data + StringPiece data; + s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record); + if (!s.ok()) { + if (errors::IsOutOfRange(s)) { + s = errors::DataLoss("truncated record at ", *offset); + } + return s; + } + if (record->data() != data.data()) { + // RandomAccessFile placed the data in some other location. + memmove(&(*record)[0], data.data(), data.size()); + } + + record->resize(data.size()); + *offset += kHeaderSize + length + kFooterSize; + return Status::OK(); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h new file mode 100644 index 0000000000..a8c1b0dd5d --- /dev/null +++ b/tensorflow/core/lib/io/record_reader.h @@ -0,0 +1,36 @@ +#ifndef TENSORFLOW_LIB_IO_RECORD_READER_H_ +#define TENSORFLOW_LIB_IO_RECORD_READER_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class RandomAccessFile; + +namespace io { + +class RecordReader { + public: + // Create a reader that will return log records from "*file". + // "*file" must remain live while this Reader is in use. + explicit RecordReader(RandomAccessFile* file); + + ~RecordReader(); + + // Read the record at "*offset" into *record and update *offset to + // point to the offset of the next record. Returns OK on success, + // OUT_OF_RANGE for end of file, or something else for an error. 
+ Status ReadRecord(uint64* offset, string* record); + + private: + RandomAccessFile* src_; + + TF_DISALLOW_COPY_AND_ASSIGN(RecordReader); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_RECORD_READER_H_ diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc new file mode 100644 index 0000000000..3d7f1509ab --- /dev/null +++ b/tensorflow/core/lib/io/record_writer.cc @@ -0,0 +1,42 @@ +#include "tensorflow/core/lib/io/record_writer.h" + +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" + +namespace tensorflow { +namespace io { + +RecordWriter::RecordWriter(WritableFile* dest) : dest_(dest) {} + +RecordWriter::~RecordWriter() {} + +static uint32 MaskedCrc(const char* data, size_t n) { + return crc32c::Mask(crc32c::Value(data, n)); +} + +Status RecordWriter::WriteRecord(StringPiece data) { + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + char header[sizeof(uint64) + sizeof(uint32)]; + core::EncodeFixed64(header + 0, data.size()); + core::EncodeFixed32(header + sizeof(uint64), + MaskedCrc(header, sizeof(uint64))); + Status s = dest_->Append(StringPiece(header, sizeof(header))); + if (!s.ok()) { + return s; + } + s = dest_->Append(data); + if (!s.ok()) { + return s; + } + char footer[sizeof(uint32)]; + core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size())); + return dest_->Append(StringPiece(footer, sizeof(footer))); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h new file mode 100644 index 0000000000..c7af00e5ae --- /dev/null +++ b/tensorflow/core/lib/io/record_writer.h @@ -0,0 +1,34 @@ +#ifndef TENSORFLOW_LIB_IO_RECORD_WRITER_H_ +#define TENSORFLOW_LIB_IO_RECORD_WRITER_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class WritableFile; + +namespace io { + +class RecordWriter { + public: + // Create a writer that will append data to "*dest". + // "*dest" must be initially empty. + // "*dest" must remain live while this Writer is in use. + explicit RecordWriter(WritableFile* dest); + + ~RecordWriter(); + + Status WriteRecord(StringPiece slice); + + private: + WritableFile* const dest_; + + TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_RECORD_WRITER_H_ diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc new file mode 100644 index 0000000000..3e9c816443 --- /dev/null +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -0,0 +1,245 @@ +#include "tensorflow/core/lib/io/record_reader.h" +#include +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace io { + +// Construct a string of the specified length made out of the supplied +// partial string. 
+static string BigString(const string& partial_string, size_t n) { + string result; + while (result.size() < n) { + result.append(partial_string); + } + result.resize(n); + return result; +} + +// Construct a string from a number +static string NumberString(int n) { + char buf[50]; + snprintf(buf, sizeof(buf), "%d.", n); + return string(buf); +} + +// Return a skewed potentially long string +static string RandomSkewedString(int i, random::SimplePhilox* rnd) { + return BigString(NumberString(i), rnd->Skewed(17)); +} + +class RecordioTest : public testing::Test { + private: + class StringDest : public WritableFile { + public: + string contents_; + + Status Close() override { return Status::OK(); } + Status Flush() override { return Status::OK(); } + Status Sync() override { return Status::OK(); } + Status Append(const StringPiece& slice) override { + contents_.append(slice.data(), slice.size()); + return Status::OK(); + } + }; + + class StringSource : public RandomAccessFile { + public: + StringPiece contents_; + mutable bool force_error_; + mutable bool returned_partial_; + StringSource() : force_error_(false), returned_partial_(false) {} + + Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error"; + + if (force_error_) { + force_error_ = false; + returned_partial_ = true; + return errors::DataLoss("read error"); + } + + if (offset >= contents_.size()) { + return errors::OutOfRange("end of file"); + } + + if (contents_.size() < offset + n) { + n = contents_.size() - offset; + returned_partial_ = true; + } + *result = StringPiece(contents_.data() + offset, n); + return Status::OK(); + } + }; + + StringDest dest_; + StringSource source_; + bool reading_; + uint64 readpos_; + RecordWriter* writer_; + RecordReader* reader_; + + public: + RecordioTest() + : reading_(false), + readpos_(0), + writer_(new RecordWriter(&dest_)), + reader_(new RecordReader(&source_)) {} + + ~RecordioTest() override { + delete writer_; + delete reader_; + } + + void Write(const string& msg) { + ASSERT_TRUE(!reading_) << "Write() after starting to read"; + ASSERT_OK(writer_->WriteRecord(StringPiece(msg))); + } + + size_t WrittenBytes() const { return dest_.contents_.size(); } + + string Read() { + if (!reading_) { + reading_ = true; + source_.contents_ = StringPiece(dest_.contents_); + } + string record; + Status s = reader_->ReadRecord(&readpos_, &record); + if (s.ok()) { + return record; + } else if (errors::IsOutOfRange(s)) { + return "EOF"; + } else { + return s.ToString(); + } + } + + void IncrementByte(int offset, int delta) { + dest_.contents_[offset] += delta; + } + + void SetByte(int offset, char new_byte) { + dest_.contents_[offset] = new_byte; + } + + void ShrinkSize(int bytes) { + dest_.contents_.resize(dest_.contents_.size() - bytes); + } + + void FixChecksum(int header_offset, int len) { + // Compute crc of type/len/data + uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len); + crc = crc32c::Mask(crc); + core::EncodeFixed32(&dest_.contents_[header_offset], crc); + } + + void ForceError() { source_.force_error_ = true; } + + void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; } + + void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) { + Write("foo"); + Write("bar"); + Write(BigString("x", 10000)); + reading_ = true; + source_.contents_ = StringPiece(dest_.contents_); + uint64 offset = WrittenBytes() + offset_past_end; + string record; + Status 
s = reader_->ReadRecord(&offset, &record); + ASSERT_TRUE(errors::IsOutOfRange(s)) << s; + } +}; + +TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); } + +TEST_F(RecordioTest, ReadWrite) { + Write("foo"); + Write("bar"); + Write(""); + Write("xxxx"); + ASSERT_EQ("foo", Read()); + ASSERT_EQ("bar", Read()); + ASSERT_EQ("", Read()); + ASSERT_EQ("xxxx", Read()); + ASSERT_EQ("EOF", Read()); + ASSERT_EQ("EOF", Read()); // Make sure reads at eof work +} + +TEST_F(RecordioTest, ManyRecords) { + for (int i = 0; i < 100000; i++) { + Write(NumberString(i)); + } + for (int i = 0; i < 100000; i++) { + ASSERT_EQ(NumberString(i), Read()); + } + ASSERT_EQ("EOF", Read()); +} + +TEST_F(RecordioTest, RandomRead) { + const int N = 500; + { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int i = 0; i < N; i++) { + Write(RandomSkewedString(i, &rnd)); + } + } + { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int i = 0; i < N; i++) { + ASSERT_EQ(RandomSkewedString(i, &rnd), Read()); + } + } + ASSERT_EQ("EOF", Read()); +} + +// Tests of all the error paths in log_reader.cc follow: +static void AssertHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " + << expected; +} + +TEST_F(RecordioTest, ReadError) { + Write("foo"); + ForceError(); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptLength) { + Write("foo"); + IncrementByte(6, 100); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptLengthCrc) { + Write("foo"); + IncrementByte(10, 100); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptData) { + Write("foo"); + IncrementByte(14, 10); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptDataCrc) { + Write("foo"); + IncrementByte(WrittenBytes() - 1, 10); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); } + +TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); } + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc new file mode 100644 index 0000000000..769d7e72a5 --- /dev/null +++ b/tensorflow/core/lib/io/table.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
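// Editorial sketch (not part of the original source file): the typical read
// path over one of these tables, written as if it lived alongside the
// implementation below (i.e. after the #includes, inside namespace
// tensorflow::table). It assumes the caller already holds an open
// RandomAccessFile* and knows its size; Table::Open does not take ownership of
// the file, and the caller deletes both the iterator and the table when done.
static Status CountTableEntries(RandomAccessFile* file, uint64 file_size,
                                uint64* num_entries) {
  Table* tbl = NULL;
  Status s = Table::Open(Options(), file, file_size, &tbl);
  if (!s.ok()) return s;
  Iterator* iter = tbl->NewIterator();
  *num_entries = 0;
  for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
    // iter->key() / iter->value() are StringPieces that remain valid only
    // until the iterator is advanced, so copy them out if they must outlive
    // this loop.
    ++*num_entries;
  }
  s = iter->status();  // Surface any error encountered while reading blocks.
  delete iter;
  delete tbl;
  return s;
}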
+ +#include "tensorflow/core/lib/io/table.h" + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/lib/io/two_level_iterator.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace table { + +struct Table::Rep { + ~Rep() { delete index_block; } + + Options options; + Status status; + RandomAccessFile* file; + // XXX uint64 cache_id; + + BlockHandle metaindex_handle; // Handle to metaindex_block: saved from footer + Block* index_block; +}; + +Status Table::Open(const Options& options, RandomAccessFile* file, + uint64 size, Table** table) { + *table = NULL; + if (size < Footer::kEncodedLength) { + return errors::DataLoss("file is too short to be an sstable"); + } + + char footer_space[Footer::kEncodedLength]; + StringPiece footer_input; + Status s = + file->Read(size - Footer::kEncodedLength, Footer::kEncodedLength, + &footer_input, footer_space); + if (!s.ok()) return s; + + Footer footer; + s = footer.DecodeFrom(&footer_input); + if (!s.ok()) return s; + + // Read the index block + BlockContents contents; + Block* index_block = NULL; + if (s.ok()) { + s = ReadBlock(file, footer.index_handle(), &contents); + if (s.ok()) { + index_block = new Block(contents); + } + } + + if (s.ok()) { + // We've successfully read the footer and the index block: we're + // ready to serve requests. + Rep* rep = new Table::Rep; + rep->options = options; + rep->file = file; + rep->metaindex_handle = footer.metaindex_handle(); + rep->index_block = index_block; + // XXX rep->cache_id = (options.block_cache ? + // options.block_cache->NewId() : 0); + *table = new Table(rep); + } else { + if (index_block) delete index_block; + } + + return s; +} + +Table::~Table() { delete rep_; } + +static void DeleteBlock(void* arg, void* ignored) { + delete reinterpret_cast(arg); +} + +// Convert an index iterator value (i.e., an encoded BlockHandle) +// into an iterator over the contents of the corresponding block. +Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) { + Table* table = reinterpret_cast(arg); + // Cache* block_cache = table->rep_->options.block_cache; + Block* block = NULL; + // Cache::Handle* cache_handle = NULL; + + BlockHandle handle; + StringPiece input = index_value; + Status s = handle.DecodeFrom(&input); + // We intentionally allow extra stuff in index_value so that we + // can add more features in the future. 
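  // (Editorial comment, not in the original source.) With the handle decoded,
  // the code below reads the raw block from the underlying file, wraps it in a
  // Block, and hands back that block's iterator with a cleanup registered to
  // delete the Block; on any failure it returns an error iterator carrying "s".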
+ + if (s.ok()) { + BlockContents contents; + s = ReadBlock(table->rep_->file, handle, &contents); + if (s.ok()) { + block = new Block(contents); + } + } + + Iterator* iter; + if (block != NULL) { + iter = block->NewIterator(); + iter->RegisterCleanup(&DeleteBlock, block, NULL); + } else { + iter = NewErrorIterator(s); + } + return iter; +} + +Iterator* Table::NewIterator() const { + return NewTwoLevelIterator(rep_->index_block->NewIterator(), + &Table::BlockReader, const_cast(this)); +} + +Status Table::InternalGet(const StringPiece& k, void* arg, + void (*saver)(void*, const StringPiece&, + const StringPiece&)) { + Status s; + Iterator* iiter = rep_->index_block->NewIterator(); + iiter->Seek(k); + if (iiter->Valid()) { + BlockHandle handle; + Iterator* block_iter = BlockReader(this, iiter->value()); + block_iter->Seek(k); + if (block_iter->Valid()) { + (*saver)(arg, block_iter->key(), block_iter->value()); + } + s = block_iter->status(); + delete block_iter; + } + if (s.ok()) { + s = iiter->status(); + } + delete iiter; + return s; +} + +uint64 Table::ApproximateOffsetOf(const StringPiece& key) const { + Iterator* index_iter = rep_->index_block->NewIterator(); + index_iter->Seek(key); + uint64 result; + if (index_iter->Valid()) { + BlockHandle handle; + StringPiece input = index_iter->value(); + Status s = handle.DecodeFrom(&input); + if (s.ok()) { + result = handle.offset(); + } else { + // Strange: we can't decode the block handle in the index block. + // We'll just return the offset of the metaindex block, which is + // close to the whole file size for this case. + result = rep_->metaindex_handle.offset(); + } + } else { + // key is past the last key in the file. Approximate the offset + // by returning the offset of the metaindex block (which is + // right near the end of the file). + result = rep_->metaindex_handle.offset(); + } + delete index_iter; + return result; +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h new file mode 100644 index 0000000000..230dded2d4 --- /dev/null +++ b/tensorflow/core/lib/io/table.h @@ -0,0 +1,76 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_TABLE_H_ +#define TENSORFLOW_LIB_IO_TABLE_H_ + +#include +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +class RandomAccessFile; + +namespace table { + +class Block; +class BlockHandle; +class Footer; +struct Options; + +// A Table is a sorted map from strings to strings. Tables are +// immutable and persistent. A Table may be safely accessed from +// multiple threads without external synchronization. +class Table { + public: + // Attempt to open the table that is stored in bytes [0..file_size) + // of "file", and read the metadata entries necessary to allow + // retrieving data from the table. + // + // If successful, returns ok and sets "*table" to the newly opened + // table. The client should delete "*table" when no longer needed. + // If there was an error while initializing the table, sets "*table" + // to NULL and returns a non-ok status. Does not take ownership of + // "*file", but the client must ensure that "file" remains live + // for the duration of the returned table's lifetime. 
+ static Status Open(const Options& options, RandomAccessFile* file, + uint64 file_size, Table** table); + + ~Table(); + + // Returns a new iterator over the table contents. + // The result of NewIterator() is initially invalid (caller must + // call one of the Seek methods on the iterator before using it). + Iterator* NewIterator() const; + + // Given a key, return an approximate byte offset in the file where + // the data for that key begins (or would begin if the key were + // present in the file). The returned value is in terms of file + // bytes, and so includes effects like compression of the underlying data. + // E.g., the approximate offset of the last key in the table will + // be close to the file length. + uint64 ApproximateOffsetOf(const StringPiece& key) const; + + private: + struct Rep; + Rep* rep_; + + explicit Table(Rep* rep) { rep_ = rep; } + static Iterator* BlockReader(void*, const StringPiece&); + + // Calls (*handle_result)(arg, ...) with the entry found after a call + // to Seek(key). May not make such a call if filter policy says + // that key is not present. + Status InternalGet(const StringPiece& key, void* arg, + void (*handle_result)(void* arg, const StringPiece& k, + const StringPiece& v)); + + // No copying allowed + Table(const Table&); + void operator=(const Table&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_H_ diff --git a/tensorflow/core/lib/io/table_builder.cc b/tensorflow/core/lib/io/table_builder.cc new file mode 100644 index 0000000000..b786888b30 --- /dev/null +++ b/tensorflow/core/lib/io/table_builder.cc @@ -0,0 +1,263 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/table_builder.h" + +#include +#include "tensorflow/core/lib/io/block_builder.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace table { + +namespace { + +void FindShortestSeparator(string* start, const StringPiece& limit) { + // Find length of common prefix + size_t min_length = std::min(start->size(), limit.size()); + size_t diff_index = 0; + while ((diff_index < min_length) && + ((*start)[diff_index] == limit[diff_index])) { + diff_index++; + } + + if (diff_index >= min_length) { + // Do not shorten if one string is a prefix of the other + } else { + uint8 diff_byte = static_cast((*start)[diff_index]); + if (diff_byte < static_cast(0xff) && + diff_byte + 1 < static_cast(limit[diff_index])) { + (*start)[diff_index]++; + start->resize(diff_index + 1); + assert(StringPiece(*start).compare(limit) < 0); + } + } +} + +void FindShortSuccessor(string* key) { + // Find first character that can be incremented + size_t n = key->size(); + for (size_t i = 0; i < n; i++) { + const uint8 byte = (*key)[i]; + if (byte != static_cast(0xff)) { + (*key)[i] = byte + 1; + key->resize(i + 1); + return; + } + } + // *key is a run of 0xffs. Leave it alone. 
+} +} // namespace + +struct TableBuilder::Rep { + Options options; + Options index_block_options; + WritableFile* file; + uint64 offset; + Status status; + BlockBuilder data_block; + BlockBuilder index_block; + string last_key; + int64 num_entries; + bool closed; // Either Finish() or Abandon() has been called. + + // We do not emit the index entry for a block until we have seen the + // first key for the next data block. This allows us to use shorter + // keys in the index block. For example, consider a block boundary + // between the keys "the quick brown fox" and "the who". We can use + // "the r" as the key for the index block entry since it is >= all + // entries in the first block and < all entries in subsequent + // blocks. + // + // Invariant: r->pending_index_entry is true only if data_block is empty. + bool pending_index_entry; + BlockHandle pending_handle; // Handle to add to index block + + string compressed_output; + + Rep(const Options& opt, WritableFile* f) + : options(opt), + index_block_options(opt), + file(f), + offset(0), + data_block(&options), + index_block(&index_block_options), + num_entries(0), + closed(false), + pending_index_entry(false) { + index_block_options.block_restart_interval = 1; + } +}; + +TableBuilder::TableBuilder(const Options& options, WritableFile* file) + : rep_(new Rep(options, file)) {} + +TableBuilder::~TableBuilder() { + assert(rep_->closed); // Catch errors where caller forgot to call Finish() + delete rep_; +} + +void TableBuilder::Add(const StringPiece& key, const StringPiece& value) { + Rep* r = rep_; + assert(!r->closed); + if (!ok()) return; + if (r->num_entries > 0) { + assert(key.compare(StringPiece(r->last_key)) > 0); + // See if this key+value would make our current block overly large. If + // so, emit the current block before adding this key/value + const int kOverlyLargeBlockRatio = 2; + const size_t this_entry_bytes = key.size() + value.size(); + if (this_entry_bytes >= kOverlyLargeBlockRatio * r->options.block_size) { + Flush(); + } + } + + if (r->pending_index_entry) { + assert(r->data_block.empty()); + FindShortestSeparator(&r->last_key, key); + string handle_encoding; + r->pending_handle.EncodeTo(&handle_encoding); + r->index_block.Add(r->last_key, StringPiece(handle_encoding)); + r->pending_index_entry = false; + } + + r->last_key.assign(key.data(), key.size()); + r->num_entries++; + r->data_block.Add(key, value); + + const size_t estimated_block_size = r->data_block.CurrentSizeEstimate(); + if (estimated_block_size >= r->options.block_size) { + Flush(); + } +} + +void TableBuilder::Flush() { + Rep* r = rep_; + assert(!r->closed); + if (!ok()) return; + if (r->data_block.empty()) return; + assert(!r->pending_index_entry); + WriteBlock(&r->data_block, &r->pending_handle); + if (ok()) { + r->pending_index_entry = true; + r->status = r->file->Flush(); + } +} + +void TableBuilder::WriteBlock(BlockBuilder* block, BlockHandle* handle) { + // File format contains a sequence of blocks where each block has: + // block_data: uint8[n] + // type: uint8 + // crc: uint32 + assert(ok()); + Rep* r = rep_; + StringPiece raw = block->Finish(); + + StringPiece block_contents; + CompressionType type = r->options.compression; + // TODO(postrelease): Support more compression options: zlib? 
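  // (Editorial comment, not in the original source.) The switch below picks the
  // on-disk representation: the Snappy output is kept only when it is more than
  // 12.5% smaller than the raw block; otherwise, or when Snappy support is
  // unavailable, the block is stored uncompressed and the type byte is
  // downgraded to kNoCompression.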
+ switch (type) { + case kNoCompression: + block_contents = raw; + break; + + case kSnappyCompression: { + string* compressed = &r->compressed_output; + if (port::Snappy_Compress(raw.data(), raw.size(), compressed) && + compressed->size() < raw.size() - (raw.size() / 8u)) { + block_contents = *compressed; + } else { + // Snappy not supported, or compressed less than 12.5%, so just + // store uncompressed form + block_contents = raw; + type = kNoCompression; + } + break; + } + } + WriteRawBlock(block_contents, type, handle); + r->compressed_output.clear(); + block->Reset(); +} + +void TableBuilder::WriteRawBlock(const StringPiece& block_contents, + CompressionType type, BlockHandle* handle) { + Rep* r = rep_; + handle->set_offset(r->offset); + handle->set_size(block_contents.size()); + r->status = r->file->Append(block_contents); + if (r->status.ok()) { + char trailer[kBlockTrailerSize]; + trailer[0] = type; + uint32 crc = crc32c::Value(block_contents.data(), block_contents.size()); + crc = crc32c::Extend(crc, trailer, 1); // Extend crc to cover block type + core::EncodeFixed32(trailer + 1, crc32c::Mask(crc)); + r->status = r->file->Append(StringPiece(trailer, kBlockTrailerSize)); + if (r->status.ok()) { + r->offset += block_contents.size() + kBlockTrailerSize; + } + } +} + +Status TableBuilder::status() const { return rep_->status; } + +Status TableBuilder::Finish() { + Rep* r = rep_; + Flush(); + assert(!r->closed); + r->closed = true; + + BlockHandle metaindex_block_handle, index_block_handle; + + // Write metaindex block + if (ok()) { + BlockBuilder meta_index_block(&r->options); + // TODO(postrelease): Add stats and other meta blocks + WriteBlock(&meta_index_block, &metaindex_block_handle); + } + + // Write index block + if (ok()) { + if (r->pending_index_entry) { + FindShortSuccessor(&r->last_key); + string handle_encoding; + r->pending_handle.EncodeTo(&handle_encoding); + r->index_block.Add(r->last_key, StringPiece(handle_encoding)); + r->pending_index_entry = false; + } + WriteBlock(&r->index_block, &index_block_handle); + } + + // Write footer + if (ok()) { + Footer footer; + footer.set_metaindex_handle(metaindex_block_handle); + footer.set_index_handle(index_block_handle); + string footer_encoding; + footer.EncodeTo(&footer_encoding); + r->status = r->file->Append(footer_encoding); + if (r->status.ok()) { + r->offset += footer_encoding.size(); + } + } + return r->status; +} + +void TableBuilder::Abandon() { + Rep* r = rep_; + assert(!r->closed); + r->closed = true; +} + +uint64 TableBuilder::NumEntries() const { return rep_->num_entries; } + +uint64 TableBuilder::FileSize() const { return rep_->offset; } + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h new file mode 100644 index 0000000000..cebf4d8e0c --- /dev/null +++ b/tensorflow/core/lib/io/table_builder.h @@ -0,0 +1,87 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// TableBuilder provides the interface used to build a Table +// (an immutable and sorted map from keys to values). +// +// Multiple threads can invoke const methods on a TableBuilder without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same TableBuilder must use +// external synchronization. 
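// Editorial sketch (not part of the original header): the expected calling
// sequence for TableBuilder, written as if it lived inside namespace
// tensorflow::table with the declarations below in scope. Keys must be added
// in increasing lexicographic order, and exactly one of Finish() or Abandon()
// must be called before the builder is destroyed; the caller keeps ownership
// of the WritableFile and closes it afterwards. The helper name and the
// std::vector-of-pairs input are illustrative only.
static Status WriteSortedTable(
    WritableFile* file,
    const std::vector<std::pair<string, string> >& sorted_entries) {
  TableBuilder builder(Options(), file);
  for (size_t i = 0; i < sorted_entries.size(); i++) {
    builder.Add(sorted_entries[i].first, sorted_entries[i].second);
    if (!builder.status().ok()) {
      builder.Abandon();  // Required, since Finish() will not be called.
      return builder.status();
    }
  }
  return builder.Finish();  // The builder stops using "file" after this call.
}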
+ +#ifndef TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ +#define TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ + +#include +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +class WritableFile; +namespace table { + +class BlockBuilder; +class BlockHandle; + +class TableBuilder { + public: + // Create a builder that will store the contents of the table it is + // building in *file. Does not close the file. It is up to the + // caller to close the file after calling Finish(). + TableBuilder(const Options& options, WritableFile* file); + + // REQUIRES: Either Finish() or Abandon() has been called. + ~TableBuilder(); + + // Add key,value to the table being constructed. + // REQUIRES: key is after any previously added key in lexicographic order. + // REQUIRES: Finish(), Abandon() have not been called + void Add(const StringPiece& key, const StringPiece& value); + + // Advanced operation: flush any buffered key/value pairs to file. + // Can be used to ensure that two adjacent entries never live in + // the same data block. Most clients should not need to use this method. + // REQUIRES: Finish(), Abandon() have not been called + void Flush(); + + // Return non-ok iff some error has been detected. + Status status() const; + + // Finish building the table. Stops using the file passed to the + // constructor after this function returns. + // REQUIRES: Finish(), Abandon() have not been called + Status Finish(); + + // Indicate that the contents of this builder should be abandoned. Stops + // using the file passed to the constructor after this function returns. + // If the caller is not going to call Finish(), it must call Abandon() + // before destroying this builder. + // REQUIRES: Finish(), Abandon() have not been called + void Abandon(); + + // Number of calls to Add() so far. + uint64 NumEntries() const; + + // Size of the file generated so far. If invoked after a successful + // Finish() call, returns the size of the final generated file. + uint64 FileSize() const; + + private: + bool ok() const { return status().ok(); } + void WriteBlock(BlockBuilder* block, BlockHandle* handle); + void WriteRawBlock(const StringPiece& data, CompressionType, + BlockHandle* handle); + + struct Rep; + Rep* rep_; + + // No copying allowed + TableBuilder(const TableBuilder&); + void operator=(const TableBuilder&); +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_BUILDER_H_ diff --git a/tensorflow/core/lib/io/table_format.txt b/tensorflow/core/lib/io/table_format.txt new file mode 100644 index 0000000000..7edb9fb121 --- /dev/null +++ b/tensorflow/core/lib/io/table_format.txt @@ -0,0 +1,8 @@ +File format +=========== + +The table format is heavily based on the table format for the LevelDB +open source key/value store, with the exception that our tables +do not support "filter" meta blocks (Bloom Filters). See: + +https://code.google.com/p/leveldb/source/browse/doc/table_format.txt diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h new file mode 100644 index 0000000000..45b061b03b --- /dev/null +++ b/tensorflow/core/lib/io/table_options.h @@ -0,0 +1,53 @@ +#ifndef TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ +#define TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ + +#include + +namespace tensorflow { +namespace table { + +// DB contents are stored in a set of blocks, each of which holds a +// sequence of key,value pairs. Each block may be compressed before +// being stored in a file. 
The following enum describes which +// compression method (if any) is used to compress a block. +enum CompressionType { + // NOTE: do not change the values of existing entries, as these are + // part of the persistent format on disk. + kNoCompression = 0x0, + kSnappyCompression = 0x1 +}; + +// Options to control the behavior of a table (passed to Table::Open) +struct Options { + // Approximate size of user data packed per block. Note that the + // block size specified here corresponds to uncompressed data. The + // actual size of the unit read from disk may be smaller if + // compression is enabled. This parameter can be changed dynamically. + size_t block_size = 262144; + + // Number of keys between restart points for delta encoding of keys. + // This parameter can be changed dynamically. Most clients should + // leave this parameter alone. + int block_restart_interval = 16; + + // Compress blocks using the specified compression algorithm. This + // parameter can be changed dynamically. + // + // Default: kSnappyCompression, which gives lightweight but fast + // compression. + // + // Typical speeds of kSnappyCompression on an Intel(R) Core(TM)2 2.4GHz: + // ~200-500MB/s compression + // ~400-800MB/s decompression + // Note that these speeds are significantly faster than most + // persistent storage speeds, and therefore it is typically never + // worth switching to kNoCompression. Even if the input data is + // incompressible, the kSnappyCompression implementation will + // efficiently detect that and will switch to uncompressed mode. + CompressionType compression = kSnappyCompression; +}; + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_ diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc new file mode 100644 index 0000000000..66e90ac64e --- /dev/null +++ b/tensorflow/core/lib/io/table_test.cc @@ -0,0 +1,601 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/table.h" + +#include +#include +#include +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/block_builder.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/iterator.h" +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace table { + +namespace test { +static StringPiece RandomString(random::SimplePhilox* rnd, int len, + string* dst) { + dst->resize(len); + for (int i = 0; i < len; i++) { + (*dst)[i] = static_cast(' ' + rnd->Uniform(95)); // ' ' .. '~' + } + return StringPiece(*dst); +} +static string RandomKey(random::SimplePhilox* rnd, int len) { + // Make sure to generate a wide variety of characters so we + // test the boundary conditions for short-key optimizations. 
+ static const char kTestChars[] = {'\0', '\1', 'a', 'b', 'c', + 'd', 'e', '\xfd', '\xfe', '\xff'}; + string result; + for (int i = 0; i < len; i++) { + result += kTestChars[rnd->Uniform(sizeof(kTestChars))]; + } + return result; +} +static StringPiece CompressibleString(random::SimplePhilox* rnd, + double compressed_fraction, size_t len, + string* dst) { + int raw = static_cast(len * compressed_fraction); + if (raw < 1) raw = 1; + string raw_data; + RandomString(rnd, raw, &raw_data); + + // Duplicate the random data until we have filled "len" bytes + dst->clear(); + while (dst->size() < len) { + dst->append(raw_data); + } + dst->resize(len); + return StringPiece(*dst); +} +} + +static void Increment(string* key) { key->push_back('\0'); } + +// An STL comparator that compares two StringPieces +namespace { +struct STLLessThan { + STLLessThan() {} + bool operator()(const string& a, const string& b) const { + return StringPiece(a).compare(StringPiece(b)) < 0; + } +}; +} // namespace + +class StringSink : public WritableFile { + public: + ~StringSink() {} + + const string& contents() const { return contents_; } + + virtual Status Close() { return Status::OK(); } + virtual Status Flush() { return Status::OK(); } + virtual Status Sync() { return Status::OK(); } + + virtual Status Append(const StringPiece& data) { + contents_.append(data.data(), data.size()); + return Status::OK(); + } + + private: + string contents_; +}; + +class StringSource : public RandomAccessFile { + public: + StringSource(const StringPiece& contents) + : contents_(contents.data(), contents.size()), bytes_read_(0) {} + + virtual ~StringSource() {} + + uint64 Size() const { return contents_.size(); } + + virtual Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const { + if (offset > contents_.size()) { + return errors::InvalidArgument("invalid Read offset"); + } + if (offset + n > contents_.size()) { + n = contents_.size() - offset; + } + memcpy(scratch, &contents_[offset], n); + *result = StringPiece(scratch, n); + bytes_read_ += n; + return Status::OK(); + } + + uint64 BytesRead() const { return bytes_read_; } + + private: + string contents_; + mutable uint64 bytes_read_; +}; + +typedef std::map KVMap; + +// Helper class for tests to unify the interface between +// BlockBuilder/TableBuilder and Block/Table. +class Constructor { + public: + explicit Constructor() : data_(STLLessThan()) {} + virtual ~Constructor() {} + + void Add(const string& key, const StringPiece& value) { + data_[key] = value.ToString(); + } + + // Finish constructing the data structure with all the keys that have + // been added so far. 
Returns the keys in sorted order in "*keys" + // and stores the key/value pairs in "*kvmap" + void Finish(const Options& options, std::vector* keys, KVMap* kvmap) { + *kvmap = data_; + keys->clear(); + for (KVMap::const_iterator it = data_.begin(); it != data_.end(); ++it) { + keys->push_back(it->first); + } + data_.clear(); + Status s = FinishImpl(options, *kvmap); + ASSERT_TRUE(s.ok()) << s.ToString(); + } + + // Construct the data structure from the data in "data" + virtual Status FinishImpl(const Options& options, const KVMap& data) = 0; + + virtual Iterator* NewIterator() const = 0; + + virtual const KVMap& data() { return data_; } + + private: + KVMap data_; +}; + +class BlockConstructor : public Constructor { + public: + BlockConstructor() : block_(NULL) {} + ~BlockConstructor() { delete block_; } + virtual Status FinishImpl(const Options& options, const KVMap& data) { + delete block_; + block_ = NULL; + BlockBuilder builder(&options); + + for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) { + builder.Add(it->first, it->second); + } + // Open the block + data_ = builder.Finish().ToString(); + BlockContents contents; + contents.data = data_; + contents.cachable = false; + contents.heap_allocated = false; + block_ = new Block(contents); + return Status::OK(); + } + virtual Iterator* NewIterator() const { return block_->NewIterator(); } + + private: + string data_; + Block* block_; +}; + +class TableConstructor : public Constructor { + public: + TableConstructor() : source_(NULL), table_(NULL) {} + ~TableConstructor() { Reset(); } + virtual Status FinishImpl(const Options& options, const KVMap& data) { + Reset(); + StringSink sink; + TableBuilder builder(options, &sink); + + for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) { + builder.Add(it->first, it->second); + TF_CHECK_OK(builder.status()); + } + Status s = builder.Finish(); + TF_CHECK_OK(s) << s.ToString(); + + CHECK_EQ(sink.contents().size(), builder.FileSize()); + + // Open the table + source_ = new StringSource(sink.contents()); + Options table_options; + return Table::Open(table_options, source_, sink.contents().size(), &table_); + } + + virtual Iterator* NewIterator() const { return table_->NewIterator(); } + + uint64 ApproximateOffsetOf(const StringPiece& key) const { + return table_->ApproximateOffsetOf(key); + } + + uint64 BytesRead() const { return source_->BytesRead(); } + + private: + void Reset() { + delete table_; + delete source_; + table_ = NULL; + source_ = NULL; + } + + StringSource* source_; + Table* table_; +}; + +enum TestType { TABLE_TEST, BLOCK_TEST }; + +struct TestArgs { + TestType type; + int restart_interval; +}; + +static const TestArgs kTestArgList[] = { + {TABLE_TEST, 16}, {TABLE_TEST, 1}, {TABLE_TEST, 1024}, + {BLOCK_TEST, 16}, {BLOCK_TEST, 1}, {BLOCK_TEST, 1024}, +}; +static const int kNumTestArgs = sizeof(kTestArgList) / sizeof(kTestArgList[0]); + +class Harness : public ::testing::Test { + public: + Harness() : constructor_(NULL) {} + + void Init(const TestArgs& args) { + delete constructor_; + constructor_ = NULL; + options_ = Options(); + + options_.block_restart_interval = args.restart_interval; + // Use shorter block size for tests to exercise block boundary + // conditions more. 
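// (The production default is 256 KB -- see Options::block_size -- so a
// 256-byte block size guarantees that even small test datasets span many
// blocks.)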
+ options_.block_size = 256; + switch (args.type) { + case TABLE_TEST: + constructor_ = new TableConstructor(); + break; + case BLOCK_TEST: + constructor_ = new BlockConstructor(); + break; + } + } + + ~Harness() { delete constructor_; } + + void Add(const string& key, const string& value) { + constructor_->Add(key, value); + } + + void Test(random::SimplePhilox* rnd) { + std::vector keys; + KVMap data; + constructor_->Finish(options_, &keys, &data); + + TestForwardScan(keys, data); + TestRandomAccess(rnd, keys, data); + } + + void TestForwardScan(const std::vector& keys, const KVMap& data) { + Iterator* iter = constructor_->NewIterator(); + ASSERT_TRUE(!iter->Valid()); + iter->SeekToFirst(); + for (KVMap::const_iterator model_iter = data.begin(); + model_iter != data.end(); ++model_iter) { + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + iter->Next(); + } + ASSERT_TRUE(!iter->Valid()); + delete iter; + } + + void TestRandomAccess(random::SimplePhilox* rnd, + const std::vector& keys, const KVMap& data) { + static const bool kVerbose = false; + Iterator* iter = constructor_->NewIterator(); + ASSERT_TRUE(!iter->Valid()); + KVMap::const_iterator model_iter = data.begin(); + if (kVerbose) fprintf(stderr, "---\n"); + for (int i = 0; i < 200; i++) { + const int toss = rnd->Uniform(3); + switch (toss) { + case 0: { + if (iter->Valid()) { + if (kVerbose) fprintf(stderr, "Next\n"); + iter->Next(); + ++model_iter; + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + } + break; + } + + case 1: { + if (kVerbose) fprintf(stderr, "SeekToFirst\n"); + iter->SeekToFirst(); + model_iter = data.begin(); + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + break; + } + + case 2: { + string key = PickRandomKey(rnd, keys); + model_iter = data.lower_bound(key); + if (kVerbose) + fprintf(stderr, "Seek '%s'\n", str_util::CEscape(key).c_str()); + iter->Seek(StringPiece(key)); + ASSERT_EQ(ToString(data, model_iter), ToString(iter)); + break; + } + } + } + delete iter; + } + + string ToString(const KVMap& data, const KVMap::const_iterator& it) { + if (it == data.end()) { + return "END"; + } else { + return "'" + it->first + "->" + it->second + "'"; + } + } + + string ToString(const KVMap& data, const KVMap::const_reverse_iterator& it) { + if (it == data.rend()) { + return "END"; + } else { + return "'" + it->first + "->" + it->second + "'"; + } + } + + string ToString(const Iterator* it) { + if (!it->Valid()) { + return "END"; + } else { + return "'" + it->key().ToString() + "->" + it->value().ToString() + "'"; + } + } + + string PickRandomKey(random::SimplePhilox* rnd, + const std::vector& keys) { + if (keys.empty()) { + return "foo"; + } else { + const int index = rnd->Uniform(keys.size()); + string result = keys[index]; + switch (rnd->Uniform(3)) { + case 0: + // Return an existing key + break; + case 1: { + // Attempt to return something smaller than an existing key + if (result.size() > 0 && result[result.size() - 1] > '\0') { + result[result.size() - 1]--; + } + break; + } + case 2: { + // Return something larger than an existing key + Increment(&result); + break; + } + } + return result; + } + } + + private: + Options options_; + Constructor* constructor_; +}; + +// Test empty table/block. +TEST_F(Harness, Empty) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 1, 17); + random::SimplePhilox rnd(&philox); + Test(&rnd); + } +} + +// Special test for a block with no restart entries. 
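// (Such a block is nothing but its trailer: the final four bytes of a block
// encode the restart-point count, so four zero bytes decode as a well-formed
// block with zero restarts and zero entries.)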
The C++ leveldb +// code never generates such blocks, but the Java version of leveldb +// seems to. +TEST_F(Harness, ZeroRestartPointsInBlock) { + char data[sizeof(uint32)]; + memset(data, 0, sizeof(data)); + BlockContents contents; + contents.data = StringPiece(data, sizeof(data)); + contents.cachable = false; + contents.heap_allocated = false; + Block block(contents); + Iterator* iter = block.NewIterator(); + iter->SeekToFirst(); + ASSERT_TRUE(!iter->Valid()); + iter->Seek("foo"); + ASSERT_TRUE(!iter->Valid()); + delete iter; +} + +// Test the empty key +TEST_F(Harness, SimpleEmptyKey) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 1, 17); + random::SimplePhilox rnd(&philox); + Add("", "v"); + Test(&rnd); + } +} + +TEST_F(Harness, SimpleSingle) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 2, 17); + random::SimplePhilox rnd(&philox); + Add("abc", "v"); + Test(&rnd); + } +} + +TEST_F(Harness, SimpleMulti) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 3, 17); + random::SimplePhilox rnd(&philox); + Add("abc", "v"); + Add("abcd", "v"); + Add("ac", "v2"); + Test(&rnd); + } +} + +TEST_F(Harness, SimpleMultiBigValues) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 3, 17); + random::SimplePhilox rnd(&philox); + Add("ainitial", "tiny"); + Add("anext", string(10000000, 'a')); + Add("anext2", string(10000000, 'b')); + Add("azz", "tiny"); + Test(&rnd); + } +} + +TEST_F(Harness, SimpleSpecialKey) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 4, 17); + random::SimplePhilox rnd(&philox); + Add("\xff\xff", "v3"); + Test(&rnd); + } +} + +TEST_F(Harness, Randomized) { + for (int i = 0; i < kNumTestArgs; i++) { + Init(kTestArgList[i]); + random::PhiloxRandom philox(testing::RandomSeed() + 5, 17); + random::SimplePhilox rnd(&philox); + for (int num_entries = 0; num_entries < 2000; + num_entries += (num_entries < 50 ? 
1 : 200)) { + if ((num_entries % 10) == 0) { + fprintf(stderr, "case %d of %d: num_entries = %d\n", (i + 1), + int(kNumTestArgs), num_entries); + } + for (int e = 0; e < num_entries; e++) { + string v; + Add(test::RandomKey(&rnd, rnd.Skewed(4)), + test::RandomString(&rnd, rnd.Skewed(5), &v).ToString()); + } + Test(&rnd); + } + } +} + +static bool Between(uint64 val, uint64 low, uint64 high) { + bool result = (val >= low) && (val <= high); + if (!result) { + fprintf(stderr, "Value %llu is not in range [%llu, %llu]\n", + (unsigned long long)(val), (unsigned long long)(low), + (unsigned long long)(high)); + } + return result; +} + +class TableTest {}; + +TEST(TableTest, ApproximateOffsetOfPlain) { + TableConstructor c; + c.Add("k01", "hello"); + c.Add("k02", "hello2"); + c.Add("k03", string(10000, 'x')); + c.Add("k04", string(200000, 'x')); + c.Add("k05", string(300000, 'x')); + c.Add("k06", "hello3"); + c.Add("k07", string(100000, 'x')); + std::vector keys; + KVMap kvmap; + Options options; + options.block_size = 1024; + options.compression = kNoCompression; + c.Finish(options, &keys, &kvmap); + + ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01a"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 10, 500)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 10000, 11000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04a"), 210000, 211000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k05"), 210000, 211000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k06"), 510000, 511000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k07"), 510000, 511000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 610000, 612000)); +} + +static bool SnappyCompressionSupported() { + string out; + StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; + return port::Snappy_Compress(in.data(), in.size(), &out); +} + +TEST(TableTest, ApproximateOffsetOfCompressed) { + if (!SnappyCompressionSupported()) { + fprintf(stderr, "skipping compression tests\n"); + return; + } + + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + TableConstructor c; + string tmp; + c.Add("k01", "hello"); + c.Add("k02", test::CompressibleString(&rnd, 0.25, 10000, &tmp)); + c.Add("k03", "hello3"); + c.Add("k04", test::CompressibleString(&rnd, 0.25, 10000, &tmp)); + std::vector keys; + KVMap kvmap; + Options options; + options.block_size = 1024; + options.compression = kSnappyCompression; + c.Finish(options, &keys, &kvmap); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 10, 100)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 2000, 3000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 2000, 3000)); + ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 4000, 6000)); +} + +TEST(TableTest, SeekToFirstKeyDoesNotReadTooMuch) { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + string tmp; + TableConstructor c; + c.Add("k01", "firstvalue"); + c.Add("k03", test::CompressibleString(&rnd, 0.25, 1000000, &tmp)); + c.Add("k04", "abc"); + std::vector keys; + KVMap kvmap; + Options options; + options.block_size = 1024; + options.compression = kNoCompression; + c.Finish(options, &keys, &kvmap); + + Iterator* iter = c.NewIterator(); + iter->Seek("k01"); + delete iter; + // Make sure 
we don't read the big second block when just trying to + // retrieve the data in the first key + EXPECT_LT(c.BytesRead(), 200); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/two_level_iterator.cc b/tensorflow/core/lib/io/two_level_iterator.cc new file mode 100644 index 0000000000..409baade6d --- /dev/null +++ b/tensorflow/core/lib/io/two_level_iterator.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "tensorflow/core/lib/io/two_level_iterator.h" + +#include "tensorflow/core/lib/io/table.h" +#include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/format.h" +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +namespace { + +typedef Iterator* (*BlockFunction)(void*, const StringPiece&); + +class TwoLevelIterator : public Iterator { + public: + TwoLevelIterator(Iterator* index_iter, BlockFunction block_function, + void* arg); + + virtual ~TwoLevelIterator(); + + virtual void Seek(const StringPiece& target); + virtual void SeekToFirst(); + virtual void Next(); + + virtual bool Valid() const { + return (data_iter_ == nullptr) ? false : data_iter_->Valid(); + } + virtual StringPiece key() const { + assert(Valid()); + return data_iter_->key(); + } + virtual StringPiece value() const { + assert(Valid()); + return data_iter_->value(); + } + virtual Status status() const { + // It'd be nice if status() returned a const Status& instead of a + // Status + if (!index_iter_->status().ok()) { + return index_iter_->status(); + } else if (data_iter_ != NULL && !data_iter_->status().ok()) { + return data_iter_->status(); + } else { + return status_; + } + } + + private: + void SaveError(const Status& s) { + if (status_.ok() && !s.ok()) status_ = s; + } + void SkipEmptyDataBlocksForward(); + void SetDataIterator(Iterator* data_iter); + void InitDataBlock(); + + BlockFunction block_function_; + void* arg_; + Status status_; + Iterator* index_iter_; + Iterator* data_iter_; // May be NULL + // If data_iter_ is non-NULL, then "data_block_handle_" holds the + // "index_value" passed to block_function_ to create the data_iter_. 
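// Cached so that InitDataBlock() can avoid re-creating the block iterator
// when a later Seek() lands in the same data block.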
+ string data_block_handle_; +}; + +TwoLevelIterator::TwoLevelIterator(Iterator* index_iter, + BlockFunction block_function, void* arg) + : block_function_(block_function), + arg_(arg), + index_iter_(index_iter), + data_iter_(NULL) {} + +TwoLevelIterator::~TwoLevelIterator() { + delete index_iter_; + delete data_iter_; +} + +void TwoLevelIterator::Seek(const StringPiece& target) { + index_iter_->Seek(target); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->Seek(target); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::SeekToFirst() { + index_iter_->SeekToFirst(); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->SeekToFirst(); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::Next() { + assert(Valid()); + data_iter_->Next(); + SkipEmptyDataBlocksForward(); +} + +void TwoLevelIterator::SkipEmptyDataBlocksForward() { + while (data_iter_ == NULL || !data_iter_->Valid()) { + // Move to next block + if (!index_iter_->Valid()) { + SetDataIterator(NULL); + return; + } + index_iter_->Next(); + InitDataBlock(); + if (data_iter_ != NULL) data_iter_->SeekToFirst(); + } +} + +void TwoLevelIterator::SetDataIterator(Iterator* data_iter) { + if (data_iter_ != NULL) { + SaveError(data_iter_->status()); + delete data_iter_; + } + data_iter_ = data_iter; +} + +void TwoLevelIterator::InitDataBlock() { + if (!index_iter_->Valid()) { + SetDataIterator(NULL); + } else { + StringPiece handle = index_iter_->value(); + if (data_iter_ != NULL && handle.compare(data_block_handle_) == 0) { + // data_iter_ is already constructed with this iterator, so + // no need to change anything + } else { + Iterator* iter = (*block_function_)(arg_, handle); + data_block_handle_.assign(handle.data(), handle.size()); + SetDataIterator(iter); + } + } +} + +} // namespace + +Iterator* NewTwoLevelIterator(Iterator* index_iter, + BlockFunction block_function, void* arg) { + return new TwoLevelIterator(index_iter, block_function, arg); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/two_level_iterator.h b/tensorflow/core/lib/io/two_level_iterator.h new file mode 100644 index 0000000000..1cc5d2f921 --- /dev/null +++ b/tensorflow/core/lib/io/two_level_iterator.h @@ -0,0 +1,30 @@ +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#define TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ + +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { + +// Return a new two level iterator. A two-level iterator contains an +// index iterator whose values point to a sequence of blocks where +// each block is itself a sequence of key,value pairs. The returned +// two-level iterator yields the concatenation of all key/value pairs +// in the sequence of blocks. Takes ownership of "index_iter" and +// will delete it when no longer needed. +// +// Uses a supplied function to convert an index_iter value into +// an iterator over the contents of the corresponding block. 
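// An editorial sketch of the intended call pattern. OwnerOfBlocks and its two
// methods are hypothetical stand-ins for whatever component owns the index
// and data blocks (in this library the table implementation is presumably the
// real caller); the wiring itself is generic:
static Iterator* ExampleBlockFunction(void* arg,
                                      const StringPiece& index_value) {
  // "arg" is the block owner and "index_value" names one of its blocks;
  // return an iterator over that block's key/value pairs.
  OwnerOfBlocks* owner = reinterpret_cast<OwnerOfBlocks*>(arg);
  return owner->NewBlockIterator(index_value);
}

void ExampleScan(OwnerOfBlocks* owner) {
  Iterator* iter = NewTwoLevelIterator(owner->NewIndexIterator(),
                                       &ExampleBlockFunction, owner);
  for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
    // Yields every key/value pair from every block, in index order.
  }
  delete iter;  // Also deletes the index iterator handed over above.
}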
+extern Iterator* NewTwoLevelIterator( + Iterator* index_iter, + Iterator* (*block_function)(void* arg, const StringPiece& index_value), + void* arg); + +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.cc b/tensorflow/core/lib/jpeg/jpeg_handle.cc new file mode 100644 index 0000000000..4521be0afb --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_handle.cc @@ -0,0 +1,162 @@ +// This file implements a memory destination for libjpeg +// The design is very similar to jdatadst.c in libjpeg +// These functions are not meant to be used directly, see jpeg_mem.h instead. +// We are filling out stubs required by jpeglib, those stubs are private to +// the implementation, we are just making available JPGMemSrc, JPGMemDest + +#include "tensorflow/core/lib/jpeg/jpeg_handle.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace jpeg { + +void CatchError(j_common_ptr cinfo) { + (*cinfo->err->output_message)(cinfo); + jmp_buf *jpeg_jmpbuf = reinterpret_cast(cinfo->client_data); + jpeg_destroy(cinfo); + longjmp(*jpeg_jmpbuf, 1); +} + +// ***************************************************************************** +// ***************************************************************************** +// ***************************************************************************** +// Destination functions + +// ----------------------------------------------------------------------------- +void MemInitDestination(j_compress_ptr cinfo) { + MemDestMgr *dest = reinterpret_cast(cinfo->dest); + VLOG(1) << "Initializing buffer=" << dest->bufsize << " bytes"; + dest->pub.next_output_byte = dest->buffer; + dest->pub.free_in_buffer = dest->bufsize; + dest->datacount = 0; + if (dest->dest) { + dest->dest->clear(); + } +} + +// ----------------------------------------------------------------------------- +boolean MemEmptyOutputBuffer(j_compress_ptr cinfo) { + MemDestMgr *dest = reinterpret_cast(cinfo->dest); + VLOG(1) << "Writing " << dest->bufsize << " bytes"; + if (dest->dest) { + dest->dest->append(reinterpret_cast(dest->buffer), dest->bufsize); + } + dest->pub.next_output_byte = dest->buffer; + dest->pub.free_in_buffer = dest->bufsize; + return TRUE; +} + +// ----------------------------------------------------------------------------- +void MemTermDestination(j_compress_ptr cinfo) { + MemDestMgr *dest = reinterpret_cast(cinfo->dest); + VLOG(1) << "Writing " << dest->bufsize - dest->pub.free_in_buffer << " bytes"; + if (dest->dest) { + dest->dest->append(reinterpret_cast(dest->buffer), + dest->bufsize - dest->pub.free_in_buffer); + VLOG(1) << "Total size= " << dest->dest->size(); + } + dest->datacount = dest->bufsize - dest->pub.free_in_buffer; +} + +// ----------------------------------------------------------------------------- +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize) { + SetDest(cinfo, buffer, bufsize, NULL); +} + +// ----------------------------------------------------------------------------- +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize, + string *destination) { + MemDestMgr *dest; + if (cinfo->dest == NULL) { + cinfo->dest = reinterpret_cast( + (*cinfo->mem->alloc_small)(reinterpret_cast(cinfo), + JPOOL_PERMANENT, sizeof(MemDestMgr))); + } + + dest = reinterpret_cast(cinfo->dest); + dest->bufsize = bufsize; + dest->buffer = static_cast(buffer); + dest->dest = destination; + dest->pub.init_destination = 
MemInitDestination; + dest->pub.empty_output_buffer = MemEmptyOutputBuffer; + dest->pub.term_destination = MemTermDestination; +} + +// ***************************************************************************** +// ***************************************************************************** +// ***************************************************************************** +// Source functions + +// ----------------------------------------------------------------------------- +void MemInitSource(j_decompress_ptr cinfo) { + MemSourceMgr *src = reinterpret_cast(cinfo->src); + src->pub.next_input_byte = src->data; + src->pub.bytes_in_buffer = src->datasize; +} + +// ----------------------------------------------------------------------------- +// We emulate the same error-handling as fill_input_buffer() from jdatasrc.c, +// for coherency's sake. +boolean MemFillInputBuffer(j_decompress_ptr cinfo) { + static const JOCTET kEOIBuffer[2] = {0xff, JPEG_EOI}; + MemSourceMgr *src = reinterpret_cast(cinfo->src); + if (src->pub.bytes_in_buffer == 0 && src->pub.next_input_byte == src->data) { + // empty file -> treated as an error. + ERREXIT(cinfo, JERR_INPUT_EMPTY); + return FALSE; + } else if (src->pub.bytes_in_buffer) { + // if there's still some data left, it's probably corrupted + return src->try_recover_truncated_jpeg ? TRUE : FALSE; + } else if (src->pub.next_input_byte != kEOIBuffer && + src->try_recover_truncated_jpeg) { + // In an attempt to recover truncated files, we insert a fake EOI + WARNMS(cinfo, JWRN_JPEG_EOF); + src->pub.next_input_byte = kEOIBuffer; + src->pub.bytes_in_buffer = 2; + return TRUE; + } else { + // We already inserted a fake EOI and it wasn't enough, so this time + // it's really an error. + ERREXIT(cinfo, JERR_FILE_READ); + return FALSE; + } +} + +// ----------------------------------------------------------------------------- +void MemTermSource(j_decompress_ptr cinfo) {} + +// ----------------------------------------------------------------------------- +void MemSkipInputData(j_decompress_ptr cinfo, long jump) { + MemSourceMgr *src = reinterpret_cast(cinfo->src); + src->pub.bytes_in_buffer -= jump; + src->pub.next_input_byte += jump; +} + +// ----------------------------------------------------------------------------- +void SetSrc(j_decompress_ptr cinfo, const void *data, + unsigned long int datasize, bool try_recover_truncated_jpeg) { + MemSourceMgr *src; + + cinfo->src = reinterpret_cast( + (*cinfo->mem->alloc_small)(reinterpret_cast(cinfo), + JPOOL_PERMANENT, sizeof(MemSourceMgr))); + + src = reinterpret_cast(cinfo->src); + src->pub.init_source = MemInitSource; + src->pub.fill_input_buffer = MemFillInputBuffer; + src->pub.skip_input_data = MemSkipInputData; + src->pub.resync_to_restart = jpeg_resync_to_restart; + src->pub.term_source = MemTermSource; + src->data = reinterpret_cast(data); + src->datasize = datasize; + src->pub.bytes_in_buffer = 0; + src->pub.next_input_byte = NULL; + src->try_recover_truncated_jpeg = try_recover_truncated_jpeg; +} + +} // namespace jpeg +} // namespace tensorflow diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.h b/tensorflow/core/lib/jpeg/jpeg_handle.h new file mode 100644 index 0000000000..58f7f6f666 --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_handle.h @@ -0,0 +1,51 @@ +// This file declares the functions and structures for memory I/O with libjpeg +// These functions are not meant to be used directly, see jpeg_mem.h isntead. 
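// The declarations below are glue for a recurring pattern in jpeg_mem.cc:
// install CatchError as libjpeg's error_exit hook, stash a setjmp() recovery
// point in client_data, and feed data from memory via SetSrc (or write output
// via SetDest). A condensed editorial sketch of that skeleton, essentially
// what GetImageInfo() and the decompression path in jpeg_mem.cc do:
static bool SniffJpegSize(const void* srcdata, int datasize, int* width,
                          int* height) {
  struct jpeg_decompress_struct cinfo;
  struct jpeg_error_mgr jerr;
  jmp_buf jpeg_jmpbuf;
  cinfo.err = jpeg_std_error(&jerr);
  jerr.error_exit = CatchError;    // CatchError longjmps back to jpeg_jmpbuf
  cinfo.client_data = &jpeg_jmpbuf;
  if (setjmp(jpeg_jmpbuf)) {
    return false;                  // a fatal libjpeg error was reported
  }
  jpeg_create_decompress(&cinfo);
  SetSrc(&cinfo, srcdata, datasize, false /* try_recover_truncated_jpeg */);
  jpeg_read_header(&cinfo, TRUE);
  *width = cinfo.image_width;      // valid as soon as the header is parsed
  *height = cinfo.image_height;
  jpeg_destroy_decompress(&cinfo);
  return true;
}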
+ +#ifndef TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_ +#define TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_ + +extern "C" { +#include "external/jpeg_archive/jpeg-9a/jinclude.h" +#include "external/jpeg_archive/jpeg-9a/jpeglib.h" +#include "external/jpeg_archive/jpeg-9a/jerror.h" +#include "external/jpeg_archive/jpeg-9a/transupp.h" // for rotations +} + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace jpeg { + +// Handler for fatal JPEG library errors: clean up & return +void CatchError(j_common_ptr cinfo); + +typedef struct { + struct jpeg_destination_mgr pub; + JOCTET *buffer; + int bufsize; + int datacount; + string *dest; +} MemDestMgr; + +typedef struct { + struct jpeg_source_mgr pub; + const unsigned char *data; + unsigned long int datasize; + bool try_recover_truncated_jpeg; +} MemSourceMgr; + +void SetSrc(j_decompress_ptr cinfo, const void *data, + unsigned long int datasize, bool try_recover_truncated_jpeg); + +// JPEG destination: we will store all the data in a buffer "buffer" of total +// size "bufsize", if the buffer overflows, we will be in trouble. +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize); +// Same as above, except that buffer is only used as a temporary structure and +// is emptied into "destination" as soon as it fills up. +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize, + string *destination); + +} // namespace jpeg +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_ diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc new file mode 100644 index 0000000000..556f13e388 --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc @@ -0,0 +1,557 @@ +// This file defines functions to compress and uncompress JPEG data +// to and from memory, as well as some direct manipulations of JPEG string + +#include "tensorflow/core/lib/jpeg/jpeg_mem.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/core/lib/jpeg/jpeg_handle.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace jpeg { + +// ----------------------------------------------------------------------------- +// Decompression + +namespace { + +enum JPEGErrors { + JPEGERRORS_OK, + JPEGERRORS_UNEXPECTED_END_OF_DATA, + JPEGERRORS_BAD_PARAM +}; + +// Prevent bad compiler behaviour in ASAN mode by wrapping most of the +// arguments in a struct struct. 
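// (setjmp/longjmp may clobber non-volatile locals that are modified after the
// setjmp call, so the values are kept behind a single pointer instead; see the
// longer explanation above Uncompress() further down in this file.)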
+class FewerArgsForCompiler { + public: + FewerArgsForCompiler(int datasize, const UncompressFlags& flags, int* nwarn, + std::function allocate_output) + : datasize_(datasize), + flags_(flags), + pnwarn_(nwarn), + allocate_output_(allocate_output), + fraction_read_(0.), + height_(0), + stride_(0) { + if (pnwarn_ != nullptr) *pnwarn_ = 0; + } + + const int datasize_; + const UncompressFlags flags_; + int* const pnwarn_; + std::function allocate_output_; + float fraction_read_; // fraction of scanline lines successfully read + int height_; + int stride_; +}; + +uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { + // unpack the argball + const int datasize = argball->datasize_; + const auto& flags = argball->flags_; + const int ratio = flags.ratio; + int components = flags.components; + int stride = flags.stride; // may be 0 + int* const nwarn = argball->pnwarn_; // may be NULL + + // can't decode if the ratio is not recognized by libjpeg + if ((ratio != 1) && (ratio != 2) && (ratio != 4) && (ratio != 8)) { + return nullptr; + } + + // if empty image, return + if (datasize == 0 || srcdata == NULL) return nullptr; + + // Declare temporary buffer pointer here so that we can free on error paths + JSAMPLE* tempdata = nullptr; + + // Initialize libjpeg structures to have a memory source + // Modify the usual jpeg error manager to catch fatal errors. + JPEGErrors error = JPEGERRORS_OK; + struct jpeg_decompress_struct cinfo; + struct jpeg_error_mgr jerr; + cinfo.err = jpeg_std_error(&jerr); + jmp_buf jpeg_jmpbuf; + cinfo.client_data = &jpeg_jmpbuf; + jerr.error_exit = CatchError; + if (setjmp(jpeg_jmpbuf)) { + return nullptr; + } + + jpeg_create_decompress(&cinfo); + SetSrc(&cinfo, srcdata, datasize, flags.try_recover_truncated_jpeg); + jpeg_read_header(&cinfo, TRUE); + + // Set components automatically if desired + if (components == 0) components = cinfo.num_components; + + // set grayscale and ratio parameters + switch (components) { + case 1: + cinfo.out_color_space = JCS_GRAYSCALE; + break; + case 3: + case 4: + if (cinfo.jpeg_color_space == JCS_CMYK || + cinfo.jpeg_color_space == JCS_YCCK) { + // always use cmyk for output in a 4 channel jpeg. libjpeg has a builtin + // decoder. + cinfo.out_color_space = JCS_CMYK; + } else { + cinfo.out_color_space = JCS_RGB; + } + break; + default: + LOG(ERROR) << " Invalid components value " << components << std::endl; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + cinfo.do_fancy_upsampling = boolean(flags.fancy_upscaling); + cinfo.scale_num = 1; + cinfo.scale_denom = ratio; + // Activating this has a quality/speed trade-off implication: + // cinfo.dct_method = JDCT_IFAST; + + jpeg_start_decompress(&cinfo); + + // check for compatible stride + const int min_stride = cinfo.output_width * components * sizeof(JSAMPLE); + if (stride == 0) { + stride = min_stride; + } else if (stride < min_stride) { + LOG(ERROR) << "Incompatible stride: " << stride << " < " << min_stride; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + + // Remember stride and height for use in Uncompress + argball->height_ = cinfo.output_height; + argball->stride_ = stride; + + uint8* const dstdata = argball->allocate_output_( + cinfo.output_width, cinfo.output_height, components); + if (dstdata == nullptr) { + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + JSAMPLE* output_line = static_cast(dstdata); + + // Temporary buffer used for CMYK -> RGB conversion. + const bool use_cmyk = (cinfo.out_color_space == JCS_CMYK); + tempdata = use_cmyk ? 
new JSAMPLE[cinfo.output_width * 4] : NULL; + + // If there is an error reading a line, this aborts the reading. + // Save the fraction of the image that has been read. + argball->fraction_read_ = 1.0; + while (cinfo.output_scanline < cinfo.output_height) { + int num_lines_read = 0; + if (cinfo.out_color_space == JCS_CMYK) { + num_lines_read = jpeg_read_scanlines(&cinfo, &tempdata, 1); + // Convert CMYK to RGB + for (size_t i = 0; i < cinfo.output_width; ++i) { + int c = tempdata[4 * i + 0]; + int m = tempdata[4 * i + 1]; + int y = tempdata[4 * i + 2]; + int k = tempdata[4 * i + 3]; + int r, g, b; + if (cinfo.saw_Adobe_marker) { + r = (k * c) / 255; + g = (k * m) / 255; + b = (k * y) / 255; + } else { + r = (255 - k) * (255 - c) / 255; + g = (255 - k) * (255 - m) / 255; + b = (255 - k) * (255 - y) / 255; + } + output_line[3 * i + 0] = r; + output_line[3 * i + 1] = g; + output_line[3 * i + 2] = b; + } + } else { + num_lines_read = jpeg_read_scanlines(&cinfo, &output_line, 1); + } + // Handle error cases + if (num_lines_read == 0) { + LOG(ERROR) << "Premature end of JPEG data. Stopped at line " + << cinfo.output_scanline << "/" << cinfo.output_height; + if (!flags.try_recover_truncated_jpeg) { + argball->fraction_read_ = + static_cast(cinfo.output_scanline) / cinfo.output_height; + error = JPEGERRORS_UNEXPECTED_END_OF_DATA; + } else { + for (size_t line = cinfo.output_scanline; line < cinfo.output_height; + ++line) { + if (line == 0) { + // If even the first line is missing, fill with black color + memset(output_line, 0, min_stride); + } else { + // else, just replicate the line above. + memcpy(output_line, output_line - stride, min_stride); + } + output_line += stride; + } + argball->fraction_read_ = 1.0; // consider all lines as read + // prevent error-on-exit in libjpeg: + cinfo.output_scanline = cinfo.output_height; + } + break; + } + DCHECK_EQ(num_lines_read, 1); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_line, min_stride); + output_line += stride; + } + delete[] tempdata; + + // Convert the RGB data to RGBA, with alpha set to 0xFF to indicate + // opacity. + // RGBRGBRGB... --> RGBARGBARGBA... + if (components == 4) { + // Start on the last line. + JSAMPLE* scanlineptr = + static_cast(dstdata + (cinfo.output_height - 1) * stride); + const JSAMPLE kOpaque = -1; // All ones appropriate for JSAMPLE. + const int right_rgb = (cinfo.output_width - 1) * 3; + const int right_rgba = (cinfo.output_width - 1) * 4; + + for (int y = cinfo.output_height; y-- > 0;) { + // We do all the transformations in place, going backwards for each row. + const JSAMPLE* rgb_pixel = scanlineptr + right_rgb; + JSAMPLE* rgba_pixel = scanlineptr + right_rgba; + scanlineptr -= stride; + for (int x = cinfo.output_width; x-- > 0; + rgba_pixel -= 4, rgb_pixel -= 3) { + // We copy the 3 bytes at rgb_pixel into the 4 bytes at rgba_pixel + // The "a" channel is set to be opaque. 
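// (Walking right-to-left lets the 3-byte and 4-byte layouts share one buffer:
// every RGB byte is read before the widening RGBA stores reach it.)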
+ rgba_pixel[3] = kOpaque; + rgba_pixel[2] = rgb_pixel[2]; + rgba_pixel[1] = rgb_pixel[1]; + rgba_pixel[0] = rgb_pixel[0]; + } + } + } + + switch (components) { + case 1: + if (cinfo.output_components != 1) { + error = JPEGERRORS_BAD_PARAM; + } + break; + case 3: + case 4: + if (cinfo.out_color_space == JCS_CMYK) { + if (cinfo.output_components != 4) { + error = JPEGERRORS_BAD_PARAM; + } + } else { + if (cinfo.output_components != 3) { + error = JPEGERRORS_BAD_PARAM; + } + } + break; + default: + // will never happen, should be catched by the previous switch + LOG(ERROR) << "Invalid components value " << components << std::endl; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + + // save number of warnings if requested + if (nwarn != nullptr) { + *nwarn = cinfo.err->num_warnings; + } + + // Handle errors in JPEG + switch (error) { + case JPEGERRORS_OK: + jpeg_finish_decompress(&cinfo); + break; + case JPEGERRORS_UNEXPECTED_END_OF_DATA: + case JPEGERRORS_BAD_PARAM: + jpeg_abort(reinterpret_cast(&cinfo)); + break; + default: + LOG(ERROR) << "Unhandled case " << error; + break; + } + jpeg_destroy_decompress(&cinfo); + + return dstdata; +} + +} // anonymous namespace + +// ----------------------------------------------------------------------------- +// We do the apparently silly thing of packing 5 of the arguments +// into a structure that is then passed to another routine +// that does all the work. The reason is that we want to catch +// fatal JPEG library errors with setjmp/longjmp, and g++ and +// associated libraries aren't good enough to guarantee that 7 +// parameters won't get clobbered by the longjmp. So we help +// it out a little. +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* nwarn, + std::function allocate_output) { + FewerArgsForCompiler argball(datasize, flags, nwarn, allocate_output); + uint8* const dstdata = UncompressLow(srcdata, &argball); + const float fraction_read = argball.fraction_read_; + if (dstdata == NULL || + fraction_read < std::min(1.0f, flags.min_acceptable_fraction)) { + // Major failure, none or too-partial read returned; get out + return NULL; + } + + // If there was an error in reading the jpeg data, + // set the unread pixels to black + if (fraction_read < 1.0) { + const int first_bad_line = + static_cast(fraction_read * argball.height_); + uint8* start = dstdata + first_bad_line * argball.stride_; + const int nbytes = (argball.height_ - first_bad_line) * argball.stride_; + memset(static_cast(start), 0, nbytes); + } + + return dstdata; +} + +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* pwidth, int* pheight, + int* pcomponents, int* nwarn) { + uint8* buffer = NULL; + uint8* result = + Uncompress(srcdata, datasize, flags, nwarn, + [=, &buffer](int width, int height, int components) { + if (pwidth != nullptr) *pwidth = width; + if (pheight != nullptr) *pheight = height; + if (pcomponents != nullptr) *pcomponents = components; + buffer = new uint8[height * width * components]; + return buffer; + }); + if (!result) delete[] buffer; + return result; +} + +// ---------------------------------------------------------------------------- +// Computes image information from jpeg header. +// Returns true on success; false on failure. 
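// A usage sketch tying the two entry points together (editorial illustration;
// DecodeToRGB is a hypothetical caller and error handling is minimal):
static void DecodeToRGB(const string& jpeg) {
  int width, height, components;
  if (!GetImageInfo(jpeg.data(), jpeg.size(), &width, &height, &components)) {
    LOG(ERROR) << "Not a parseable JPEG header";
    return;
  }
  UncompressFlags flags;
  flags.components = 3;  // force RGB output regardless of the file's channels
  int nwarn = 0;
  std::unique_ptr<uint8[]> pixels(Uncompress(jpeg.data(), jpeg.size(), flags,
                                             &width, &height, &components,
                                             &nwarn));
  if (pixels == nullptr) {
    LOG(ERROR) << "JPEG decode failed";
    return;
  }
  // "pixels" now holds "height" rows of width * 3 bytes each (RGBRGB...).
}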
+bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height, + int* components) { + // Init in case of failure + if (width) *width = 0; + if (height) *height = 0; + if (components) *components = 0; + + // If empty image, return + if (datasize == 0 || srcdata == NULL) return false; + + // Initialize libjpeg structures to have a memory source + // Modify the usual jpeg error manager to catch fatal errors. + struct jpeg_decompress_struct cinfo; + struct jpeg_error_mgr jerr; + jmp_buf jpeg_jmpbuf; + cinfo.err = jpeg_std_error(&jerr); + cinfo.client_data = &jpeg_jmpbuf; + jerr.error_exit = CatchError; + if (setjmp(jpeg_jmpbuf)) { + return false; + } + + // set up, read header, set image parameters, save size + jpeg_create_decompress(&cinfo); + SetSrc(&cinfo, srcdata, datasize, false); + + jpeg_read_header(&cinfo, TRUE); + jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo + if (width) *width = cinfo.output_width; + if (height) *height = cinfo.output_height; + if (components) *components = cinfo.output_components; + + jpeg_destroy_decompress(&cinfo); + + return true; +} + +// ----------------------------------------------------------------------------- +// Compression + +namespace { +bool CompressInternal(const uint8* srcdata, int width, int height, + const CompressFlags& flags, string* output) { + output->clear(); + const int components = (static_cast(flags.format) & 0xff); + int in_stride = flags.stride; + if (in_stride == 0) { + in_stride = width * (static_cast(flags.format) & 0xff); + } else if (in_stride < width * components) { + LOG(ERROR) << "Incompatible input stride"; + return false; + } + + JOCTET* buffer = 0; + + // NOTE: for broader use xmp_metadata should be made a unicode string + CHECK(srcdata != nullptr); + CHECK(output != nullptr); + // This struct contains the JPEG compression parameters and pointers to + // working space + struct jpeg_compress_struct cinfo; + // This struct represents a JPEG error handler. + struct jpeg_error_mgr jerr; + jmp_buf jpeg_jmpbuf; // recovery point in case of error + + // Step 1: allocate and initialize JPEG compression object + // Use the usual jpeg error manager. + cinfo.err = jpeg_std_error(&jerr); + cinfo.client_data = &jpeg_jmpbuf; + jerr.error_exit = CatchError; + if (setjmp(jpeg_jmpbuf)) { + output->clear(); + delete[] buffer; + return false; + } + + jpeg_create_compress(&cinfo); + + // Step 2: specify data destination + // We allocate a buffer of reasonable size. If we have a small image, just + // estimate the size of the output using the number of bytes of the input. + // If this is getting too big, we will append to the string by chunks of 1MB. + // This seems like a reasonable compromise between performance and memory. 
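// For example, a 200 x 100 RGB input starts with min(200 * 100 * 3, 1 << 20)
// = 60000 bytes, while a 4000 x 3000 RGB input is capped at 1 MB and then
// spills into "output" one full buffer at a time via MemEmptyOutputBuffer().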
+ int bufsize = std::min(width * height * components, 1 << 20); + buffer = new JOCTET[bufsize]; + SetDest(&cinfo, buffer, bufsize, output); + + // Step 3: set parameters for compression + cinfo.image_width = width; + cinfo.image_height = height; + switch (components) { + case 1: + cinfo.input_components = 1; + cinfo.in_color_space = JCS_GRAYSCALE; + break; + case 3: + case 4: + cinfo.input_components = 3; + cinfo.in_color_space = JCS_RGB; + break; + default: + LOG(ERROR) << " Invalid components value " << components << std::endl; + output->clear(); + delete[] buffer; + return false; + } + jpeg_set_defaults(&cinfo); + if (flags.optimize_jpeg_size) cinfo.optimize_coding = TRUE; + + cinfo.density_unit = flags.density_unit; // JFIF code for pixel size units: + // 1 = in, 2 = cm + cinfo.X_density = flags.x_density; // Horizontal pixel density + cinfo.Y_density = flags.y_density; // Vertical pixel density + jpeg_set_quality(&cinfo, flags.quality, TRUE); + + if (flags.progressive) { + jpeg_simple_progression(&cinfo); + } + + if (!flags.chroma_downsampling) { + // Turn off chroma subsampling (it is on by default). For more details on + // chroma subsampling, see http://en.wikipedia.org/wiki/Chroma_subsampling. + for (int i = 0; i < cinfo.num_components; ++i) { + cinfo.comp_info[i].h_samp_factor = 1; + cinfo.comp_info[i].v_samp_factor = 1; + } + } + + jpeg_start_compress(&cinfo, TRUE); + + // Embed XMP metadata if any + if (!flags.xmp_metadata.empty()) { + // XMP metadata is embedded in the APP1 tag of JPEG and requires this + // namespace header string (null-terminated) + const string name_space = "http://ns.adobe.com/xap/1.0/"; + const int name_space_length = name_space.size(); + const int metadata_length = flags.xmp_metadata.size(); + const int packet_length = metadata_length + name_space_length + 1; + std::unique_ptr joctet_packet(new JOCTET[packet_length]); + + for (int i = 0; i < name_space_length; i++) { + // Conversion char --> JOCTET + joctet_packet[i] = name_space[i]; + } + joctet_packet[name_space_length] = 0; // null-terminate namespace string + + for (int i = 0; i < metadata_length; i++) { + // Conversion char --> JOCTET + joctet_packet[i + name_space_length + 1] = flags.xmp_metadata[i]; + } + jpeg_write_marker(&cinfo, JPEG_APP0 + 1, joctet_packet.get(), + packet_length); + } + + // JSAMPLEs per row in image_buffer + std::unique_ptr row_temp( + new JSAMPLE[width * cinfo.input_components]); + while (cinfo.next_scanline < cinfo.image_height) { + JSAMPROW row_pointer[1]; // pointer to JSAMPLE row[s] + const uint8* r = &srcdata[cinfo.next_scanline * in_stride]; + uint8* p = static_cast(row_temp.get()); + switch (flags.format) { + case FORMAT_RGBA: { + for (int i = 0; i < width; ++i, p += 3, r += 4) { + p[0] = r[0]; + p[1] = r[1]; + p[2] = r[2]; + } + row_pointer[0] = row_temp.get(); + break; + } + case FORMAT_ABGR: { + for (int i = 0; i < width; ++i, p += 3, r += 4) { + p[0] = r[3]; + p[1] = r[2]; + p[2] = r[1]; + } + row_pointer[0] = row_temp.get(); + break; + } + default: { + row_pointer[0] = reinterpret_cast(const_cast(r)); + } + } + CHECK_EQ(jpeg_write_scanlines(&cinfo, row_pointer, 1), 1); + } + jpeg_finish_compress(&cinfo); + + // release JPEG compression object + jpeg_destroy_compress(&cinfo); + delete[] buffer; + return true; +} + +} // anonymous namespace + +// ----------------------------------------------------------------------------- + +bool Compress(const void* srcdata, int width, int height, + const CompressFlags& flags, string* output) { + return 
CompressInternal(static_cast(srcdata), width, height, + flags, output); +} + +string Compress(const void* srcdata, int width, int height, + const CompressFlags& flags) { + string temp; + CompressInternal(static_cast(srcdata), width, height, flags, + &temp); + // If CompressInternal fails, temp will be empty. + return temp; +} + +} // namespace jpeg +} // namespace tensorflow diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.h b/tensorflow/core/lib/jpeg/jpeg_mem.h new file mode 100644 index 0000000000..19ba7d4acf --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_mem.h @@ -0,0 +1,130 @@ +// This file defines functions to compress and uncompress JPEG files +// to and from memory. It provides interfaces for raw images +// (data array and size fields). +// Direct manipulation of JPEG strings are supplied: Flip, Rotate, Crop.. + +#ifndef TENSORFLOW_LIB_JPEG_JPEG_MEM_H_ +#define TENSORFLOW_LIB_JPEG_JPEG_MEM_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace jpeg { + +// Flags for Uncompress +struct UncompressFlags { + // ratio can be 1, 2, 4, or 8 and represent the denominator for the scaling + // factor (eg ratio = 4 means that the resulting image will be at 1/4 original + // size in both directions). + int ratio = 1; + + // The number of bytes per pixel (1, 3 or 4), or 0 for autodetect. + int components = 0; + + // If true, decoder will use a slower but nicer upscaling of the chroma + // planes (yuv420/422 only). + bool fancy_upscaling = true; + + // If true, will attempt to fill in missing lines of truncated files + bool try_recover_truncated_jpeg = false; + + // The minimum required fraction of lines read before the image is accepted. + float min_acceptable_fraction = 1.0; + + // The distance in bytes from one scanline to the other. Should be at least + // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride + // used will be this minimal value. + int stride = 0; +}; + +// Uncompress some raw JPEG data given by the pointer srcdata and the length +// datasize. +// - width and height are the address where to store the size of the +// uncompressed image in pixels. May be nullptr. +// - components is the address where the number of read components are +// stored. This is *output only*: to request a specific number of +// components use flags.components. May be nullptr. +// - nwarn is the address in which to store the number of warnings. +// May be nullptr. +// The function returns a pointer to the raw uncompressed data or NULL if +// there was an error. The caller of the function is responsible for +// freeing the memory (using delete []). +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* width, int* height, + int* components, // Output only: useful with autodetect + int* nwarn); + +// Version of Uncompress that allocates memory via a callback. The callback +// arguments are (width, height, components). If the size is known ahead of +// time this function can return an existing buffer; passing a callback allows +// the buffer to be shaped based on the JPEG header. The caller is responsible +// for freeing the memory *even along error paths*. +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* nwarn, + std::function allocate_output); + +// Read jpeg header and get image information. Returns true on success. +// The width, height, and components points may be null. 
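// A minimal sketch of the callback variant above, decoding into a
// caller-owned buffer (editorial illustration; DecodeIntoVector is a
// hypothetical helper and assumes <vector> is available):
static uint8* DecodeIntoVector(const string& jpeg, std::vector<uint8>* out) {
  UncompressFlags flags;
  flags.components = 3;  // request RGB regardless of what the file stores
  return Uncompress(jpeg.data(), jpeg.size(), flags, nullptr /* nwarn */,
                    [out](int width, int height, int components) {
                      // Invoked once the header is parsed, so the buffer can
                      // be shaped to fit; the vector keeps ownership even if
                      // decoding later fails and nullptr is returned.
                      out->resize(static_cast<size_t>(width) * height *
                                  components);
                      return out->data();
                    });
}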
+bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height, + int* components); + +// Note: (format & 0xff) = number of components (<=> bytes per pixels) +enum Format { + FORMAT_GRAYSCALE = 0x001, // 1 byte/pixel + FORMAT_RGB = 0x003, // 3 bytes/pixel RGBRGBRGBRGB... + FORMAT_RGBA = 0x004, // 4 bytes/pixel RGBARGBARGBARGBA... + FORMAT_ABGR = 0x104 // 4 bytes/pixel ABGRABGRABGR... +}; + +// Flags for compression +struct CompressFlags { + // Encoding of the input data for compression + Format format; + + // Quality of the compression from 0-100 + int quality = 95; + + // If true, create a jpeg image that loads progressively + bool progressive = false; + + // If true, reduce jpeg size without changing quality (at the cost of CPU/RAM) + bool optimize_jpeg_size = false; + + // See http://en.wikipedia.org/wiki/Chroma_subsampling + bool chroma_downsampling = true; + + // Resolution + int density_unit = 1; // 1 = in, 2 = cm + int x_density = 300; + int y_density = 300; + + // If not empty, embed this XMP metadata in the image header + StringPiece xmp_metadata; + + // The distance in bytes from one scanline to the other. Should be at least + // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride + // used will be this minimal value. + int stride = 0; +}; + +// Compress some raw image given in srcdata, the data is a 2D array of size +// stride*height with one of the formats enumerated above. +// The encoded data is returned as a string. +// If not empty, XMP metadata can be embedded in the image header +// On error, returns the empty string (which is never a valid jpeg). +string Compress(const void* srcdata, int width, int height, + const CompressFlags& flags); + +// On error, returns false and sets output to empty. +bool Compress(const void* srcdata, int width, int height, + const CompressFlags& flags, string* output); + +} // namespace jpeg +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_JPEG_JPEG_MEM_H_ diff --git a/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc new file mode 100644 index 0000000000..23e72f9d57 --- /dev/null +++ b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc @@ -0,0 +1,304 @@ +#include "tensorflow/core/lib/jpeg/jpeg_mem.h" + +#include +#include +#include +#include + +#include + +#include "tensorflow/core/lib/jpeg/jpeg_handle.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" +#include + +#include "tensorflow/core/lib/core/casts.h" + +namespace tensorflow { +namespace jpeg { +namespace { + +const char kTestData[] = "tensorflow/core/lib/jpeg/testdata/"; + +int ComputeSumAbsoluteDifference(const uint8* a, const uint8* b, int width, + int height, int a_stride, int b_stride) { + int totalerr = 0; + for (int i = 0; i < height; i++) { + const uint8* const pa = a + i * a_stride; + const uint8* const pb = b + i * b_stride; + for (int j = 0; j < 3 * width; j++) { + totalerr += abs(static_cast(pa[j]) - static_cast(pb[j])); + } + } + return totalerr; +} + +// Reads the contents of the file into output +void ReadFileToStringOrDie(Env* env, const string& filename, string* output) { + TF_CHECK_OK(ReadFileToString(env, filename, output)); +} + +void TestJPEG(Env* env, const string& jpegfile) { + // Read the data from the jpeg file into memory + string jpeg; + ReadFileToStringOrDie(Env::Default(), jpegfile, &jpeg); + const int fsize = jpeg.size(); + const uint8* const temp = bit_cast(jpeg.data()); + + // try partial 
decoding (half of the data) + int w, h, c; + std::unique_ptr imgdata; + + UncompressFlags flags; + flags.components = 3; + + // set min_acceptable_fraction to something insufficient + flags.min_acceptable_fraction = 0.8; + imgdata.reset(Uncompress(temp, fsize / 2, flags, &w, &h, &c, NULL)); + CHECK(imgdata.get() == NULL); + + // now, use a value that makes fsize/2 be enough for a black-filling + flags.min_acceptable_fraction = 0.01; + imgdata.reset(Uncompress(temp, fsize / 2, flags, &w, &h, &c, NULL)); + CHECK(imgdata.get() != NULL); + + // finally, uncompress the whole data + flags.min_acceptable_fraction = 1.0; + imgdata.reset(Uncompress(temp, fsize, flags, &w, &h, &c, NULL)); + CHECK(imgdata.get() != NULL); + + // Uncompress the data to RGBA, too + flags.min_acceptable_fraction = 1.0; + flags.components = 4; + imgdata.reset(Uncompress(temp, fsize, flags, &w, &h, &c, NULL)); + CHECK(imgdata.get() != NULL); +} + +TEST(JpegMemTest, Jpeg) { + Env* env = Env::Default(); + const string data_path = kTestData; + + // Name of a valid jpeg file on the disk + TestJPEG(env, data_path + "jpeg_merge_test1.jpg"); + + // Exercise CMYK machinery as well + TestJPEG(env, data_path + "jpeg_merge_test1_cmyk.jpg"); +} + +TEST(JpegMemTest, Jpeg2) { + // create known data, for size in_w x in_h + const int in_w = 256; + const int in_h = 256; + const int stride1 = 3 * in_w; + const std::unique_ptr refdata1(new uint8[stride1 * in_h]); + for (int i = 0; i < in_h; i++) { + for (int j = 0; j < in_w; j++) { + const int offset = i * stride1 + 3 * j; + refdata1[offset + 0] = i; + refdata1[offset + 1] = j; + refdata1[offset + 2] = static_cast((i + j) >> 1); + } + } + + // duplicate with weird input stride + const int stride2 = 3 * 357; + const std::unique_ptr refdata2(new uint8[stride2 * in_h]); + for (int i = 0; i < in_h; i++) { + memcpy(&refdata2[i * stride2], &refdata1[i * stride1], 3 * in_w); + } + + // Test compression + string cpdata1, cpdata2; + { + const string kXMP = "XMP_TEST_123"; + + // Compress it to JPEG + CompressFlags flags; + flags.format = FORMAT_RGB; + flags.quality = 97; + flags.xmp_metadata = kXMP; + cpdata1 = Compress(refdata1.get(), in_w, in_h, flags); + flags.stride = stride2; + cpdata2 = Compress(refdata2.get(), in_w, in_h, flags); + // Different input stride shouldn't change the output + CHECK_EQ(cpdata1, cpdata2); + + // Verify valid XMP. + CHECK_NE(string::npos, cpdata1.find(kXMP)); + + // Test the other API, where a storage string is supplied + string cptest; + flags.stride = 0; + Compress(refdata1.get(), in_w, in_h, flags, &cptest); + CHECK_EQ(cptest, cpdata1); + flags.stride = stride2; + Compress(refdata2.get(), in_w, in_h, flags, &cptest); + CHECK_EQ(cptest, cpdata2); + } + + // Uncompress twice: once with 3 components and once with autodetect + std::unique_ptr imgdata1; + for (const int components : {0, 3}) { + // Uncompress it + UncompressFlags flags; + flags.components = components; + int w, h, c; + imgdata1.reset( + Uncompress(cpdata1.c_str(), cpdata1.length(), flags, &w, &h, &c, NULL)); + + // Check obvious formatting stuff + CHECK_EQ(w, in_w); + CHECK_EQ(h, in_h); + CHECK_EQ(c, 3); + CHECK(imgdata1.get()); + + // Compare the two images + const int totalerr = ComputeSumAbsoluteDifference( + imgdata1.get(), refdata1.get(), in_w, in_h, stride1, stride1); + CHECK_LE(totalerr, 85000); + } + + // check the second image too. Should be bitwise identical to the first. 
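// (In the block below, each 256-pixel RGB row is 768 bytes, while the decode
// stride of 3 * 411 = 1233 bytes leaves 465 bytes of per-row padding that the
// decoder leaves untouched.)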
+ // uncompress using a weird stride + { + UncompressFlags flags; + flags.stride = 3 * 411; + const std::unique_ptr imgdata2(new uint8[flags.stride * in_h]); + CHECK(imgdata2.get() == Uncompress(cpdata2.c_str(), cpdata2.length(), flags, + NULL, [&imgdata2](int w, int h, int c) { + CHECK_EQ(w, in_w); + CHECK_EQ(h, in_h); + CHECK_EQ(c, 3); + return imgdata2.get(); + })); + const int totalerr = ComputeSumAbsoluteDifference( + imgdata1.get(), imgdata2.get(), in_w, in_h, stride1, flags.stride); + CHECK_EQ(totalerr, 0); + } +} + +// Takes JPEG data and reads its headers to determine whether or not the JPEG +// was chroma downsampled. +bool IsChromaDownsampled(const string& jpegdata) { + // Initialize libjpeg structures to have a memory source + // Modify the usual jpeg error manager to catch fatal errors. + struct jpeg_decompress_struct cinfo; + struct jpeg_error_mgr jerr; + jmp_buf jpeg_jmpbuf; + cinfo.err = jpeg_std_error(&jerr); + cinfo.client_data = &jpeg_jmpbuf; + jerr.error_exit = CatchError; + if (setjmp(jpeg_jmpbuf)) return false; + + // set up, read header, set image parameters, save size + jpeg_create_decompress(&cinfo); + SetSrc(&cinfo, jpegdata.c_str(), jpegdata.size(), false); + + jpeg_read_header(&cinfo, TRUE); + jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo + const int components = cinfo.output_components; + if (components == 1) return false; + + // Check validity + CHECK_EQ(3, components); + CHECK_EQ(cinfo.comp_info[1].h_samp_factor, cinfo.comp_info[2].h_samp_factor) + << "The h sampling factors should be the same."; + CHECK_EQ(cinfo.comp_info[1].v_samp_factor, cinfo.comp_info[2].v_samp_factor) + << "The v sampling factors should be the same."; + for (int i = 0; i < components; ++i) { + CHECK_GT(cinfo.comp_info[i].h_samp_factor, 0) << "Invalid sampling factor."; + CHECK_EQ(cinfo.comp_info[i].h_samp_factor, cinfo.comp_info[i].v_samp_factor) + << "The sampling factor should be the same in both directions."; + } + + // We're downsampled if we use fewer samples for color than for brightness. + // Do this before deallocating cinfo. + const bool downsampled = + cinfo.comp_info[1].h_samp_factor < cinfo.comp_info[0].h_samp_factor; + + jpeg_destroy_decompress(&cinfo); + return downsampled; +} + +TEST(JpegMemTest, ChromaDownsampling) { + // Read the data from a test jpeg file into memory + const string jpegfile = string(kTestData) + "jpeg_merge_test1.jpg"; + string jpeg; + ReadFileToStringOrDie(Env::Default(), jpegfile, &jpeg); + + // Verify that compressing the JPEG with chroma downsampling works. + // + // First, uncompress the JPEG. 
+ UncompressFlags unflags; + unflags.components = 3; + int w, h, c, num_warnings; + std::unique_ptr uncompressed(Uncompress( + jpeg.c_str(), jpeg.size(), unflags, &w, &h, &c, &num_warnings)); + CHECK(uncompressed.get() != NULL); + CHECK_EQ(num_warnings, 0); + + // Recompress the JPEG with and without chroma downsampling + for (const bool downsample : {false, true}) { + CompressFlags flags; + flags.format = FORMAT_RGB; + flags.quality = 85; + flags.chroma_downsampling = downsample; + string recompressed; + Compress(uncompressed.get(), w, h, flags, &recompressed); + CHECK(!recompressed.empty()); + CHECK_EQ(IsChromaDownsampled(recompressed), downsample); + } +} + +void TestBadJPEG(Env* env, const string& bad_jpeg_file, int expected_width, + int expected_height, const string& reference_RGB_file, + const bool try_recover_truncated_jpeg) { + string jpeg; + ReadFileToStringOrDie(env, bad_jpeg_file, &jpeg); + + UncompressFlags flags; + flags.components = 3; + flags.try_recover_truncated_jpeg = try_recover_truncated_jpeg; + + int width, height, components; + std::unique_ptr imgdata; + imgdata.reset(Uncompress(jpeg.c_str(), jpeg.size(), flags, &width, &height, + &components, NULL)); + if (expected_width > 0) { // we expect the file to decode into 'something' + CHECK_EQ(width, expected_width); + CHECK_EQ(height, expected_height); + CHECK_EQ(components, 3); + CHECK(imgdata.get()); + if (!reference_RGB_file.empty()) { + string ref; + ReadFileToStringOrDie(env, reference_RGB_file, &ref); + CHECK(!memcmp(ref.data(), imgdata.get(), ref.size())); + } + } else { // no decodable + CHECK(!imgdata.get()) << "file:" << bad_jpeg_file; + } +} + +TEST(JpegMemTest, BadJpeg) { + Env* env = Env::Default(); + const string data_path = kTestData; + + // Test corrupt file + TestBadJPEG(env, data_path + "bad_huffman.jpg", 1024, 768, "", false); + TestBadJPEG(env, data_path + "corrupt.jpg", 0 /*120*/, 90, "", false); + + // Truncated files, undecodable because of missing lines: + TestBadJPEG(env, data_path + "corrupt34_2.jpg", 0, 3300, "", false); + TestBadJPEG(env, data_path + "corrupt34_3.jpg", 0, 3300, "", false); + TestBadJPEG(env, data_path + "corrupt34_4.jpg", 0, 3300, "", false); + + // Try in 'recover' mode now: + TestBadJPEG(env, data_path + "corrupt34_2.jpg", 2544, 3300, "", true); + TestBadJPEG(env, data_path + "corrupt34_3.jpg", 2544, 3300, "", true); + TestBadJPEG(env, data_path + "corrupt34_4.jpg", 2544, 3300, "", true); +} + +} // namespace +} // namespace jpeg +} // namespace tensorflow diff --git a/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg b/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg new file mode 100644 index 0000000000..ef5b6f12c5 Binary files /dev/null and b/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg differ diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt.jpg new file mode 100644 index 0000000000..5e2fe6c56f Binary files /dev/null and b/tensorflow/core/lib/jpeg/testdata/corrupt.jpg differ diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg new file mode 100644 index 0000000000..4211155c45 Binary files /dev/null and b/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg differ diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg new file mode 100644 index 0000000000..c1c2a9d1e1 Binary files /dev/null and b/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg differ diff --git 
a/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg new file mode 100644 index 0000000000..b8e7308ba0 Binary files /dev/null and b/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg differ diff --git a/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg new file mode 100644 index 0000000000..5e348a12fd Binary files /dev/null and b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg differ diff --git a/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg new file mode 100644 index 0000000000..15f895960d Binary files /dev/null and b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg differ diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc new file mode 100644 index 0000000000..43b84e41e0 --- /dev/null +++ b/tensorflow/core/lib/png/png_io.cc @@ -0,0 +1,385 @@ +// Functions to read and write images in PNG format. + +#include +#include +#include +#include +#include +// NOTE(skal): we don't '#include ' before png/png.h as it otherwise +// provokes a compile error. We instead let png.h include what is needed. + +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/png/png_io.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" // endian +#include "external/png_archive/libpng-1.2.53/png.h" + +namespace tensorflow { +namespace png { + +//////////////////////////////////////////////////////////////////////////////// +// Encode an 8- or 16-bit rgb/grayscale image to PNG string +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +#define PTR_INC(type, ptr, del) (ptr = \ + reinterpret_cast(reinterpret_cast(ptr) + (del))) +#define CPTR_INC(type, ptr, del) (ptr = \ + reinterpret_cast(reinterpret_cast(ptr) + (del))) + +// Convert from 8 bit components to 16. This works in-place. +static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes, + int width, int height, uint16* p16, + int p16_row_bytes) { + // Adjust pointers to copy backwards + width *= num_comps; + CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + + (width - 1) * sizeof(*p8)); + PTR_INC(uint16, p16, (height - 1) * p16_row_bytes + + (width - 1) * sizeof(*p16)); + int bump8 = width * sizeof(*p8) - p8_row_bytes; + int bump16 = width * sizeof(*p16) - p16_row_bytes; + for (; height-- != 0; + CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) { + for (int w = width; w-- != 0; --p8, --p16) { + uint pix = *p8; + pix |= pix << 8; + *p16 = static_cast(pix); + } + } +} + +#undef PTR_INC +#undef CPTR_INC + +void ErrorHandler(png_structp png_ptr, png_const_charp msg) { + DecodeContext* const ctx = bit_cast(png_get_io_ptr(png_ptr)); + ctx->error_condition = true; + // To prevent log spam, errors are logged as VLOG(1) instead of ERROR. 
+ VLOG(1) << "PNG error: " << msg; + longjmp(png_jmpbuf(png_ptr), 1); +} + +void WarningHandler(png_structp png_ptr, png_const_charp msg) { + LOG(WARNING) << "PNG warning: " << msg; +} + +void StringReader(png_structp png_ptr, + png_bytep data, png_size_t length) { + DecodeContext* const ctx = bit_cast(png_get_io_ptr(png_ptr)); + if (static_cast(ctx->data_left) < length) { + if (!ctx->error_condition) { + VLOG(1) << "PNG read decoding error"; + ctx->error_condition = true; + } + memset(data, 0, length); + } else { + memcpy(data, ctx->data, length); + ctx->data += length; + ctx->data_left -= length; + } +} + +void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) { + string* const s = bit_cast(png_get_io_ptr(png_ptr)); + s->append(bit_cast(data), length); +} + +void StringWriterFlush(png_structp png_ptr) { +} + +char* check_metadata_string(const string& s) { + const char* const c_string = s.c_str(); + const size_t length = s.size(); + if (strlen(c_string) != length) { + LOG(WARNING) << "Warning! Metadata contains \\0 character(s)."; + } + return const_cast(c_string); +} + +} // namespace + +// We move CommonInitDecode() and CommonFinishDecode() +// out of the CommonDecode() template to save code space. +void CommonFreeDecode(DecodeContext* context) { + if (context->png_ptr) { + png_destroy_read_struct(&context->png_ptr, + context->info_ptr ? &context->info_ptr : NULL, 0); + context->png_ptr = nullptr; + context->info_ptr = nullptr; + } +} + +bool DecodeHeader(StringPiece png_string, int* width, int* height, + int* components, int* channel_bit_depth, + std::vector >* metadata) { + DecodeContext context; + // Ask for 16 bits even if there may be fewer. This assures that sniffing + // the metadata will succeed in all cases. + // + // TODO(skal): CommonInitDecode() mixes the operation of sniffing the + // metadata with setting up the data conversions. These should be separated. + constexpr int kDesiredNumChannels = 1; + constexpr int kDesiredChannelBits = 16; + if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits, + &context)) { + return false; + } + CHECK_NOTNULL(width); + *width = static_cast(context.width); + CHECK_NOTNULL(height); + *height = static_cast(context.height); + if (components != NULL) { + switch (context.color_type) { + case PNG_COLOR_TYPE_PALETTE: + *components = (context.info_ptr->valid & PNG_INFO_tRNS) ? 
4 : 3; + break; + case PNG_COLOR_TYPE_GRAY: + *components = 1; + break; + case PNG_COLOR_TYPE_GRAY_ALPHA: + *components = 2; + break; + case PNG_COLOR_TYPE_RGB: + *components = 3; + break; + case PNG_COLOR_TYPE_RGB_ALPHA: + *components = 4; + break; + default: + *components = 0; + break; + } + } + if (channel_bit_depth != NULL) { + *channel_bit_depth = context.bit_depth; + } + if (metadata != NULL) { + metadata->clear(); + for (int i = 0; i < context.info_ptr->num_text; i++) { + const png_text& text = context.info_ptr->text[i]; + metadata->push_back(std::make_pair(text.key, text.text)); + } + } + CommonFreeDecode(&context); + return true; +} + +bool CommonInitDecode(StringPiece png_string, int desired_channels, + int desired_channel_bits, DecodeContext* context) { + CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) + << "desired_channel_bits = " << desired_channel_bits; + CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = " + << desired_channels; + context->error_condition = false; + context->channels = desired_channels; + context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, + ErrorHandler, WarningHandler); + if (!context->png_ptr) { + VLOG(1) << ": DecodePNG <- png_create_read_struct failed"; + return false; + } + if (setjmp(png_jmpbuf(context->png_ptr))) { + VLOG(1) << ": DecodePNG error trapped."; + CommonFreeDecode(context); + return false; + } + context->info_ptr = png_create_info_struct(context->png_ptr); + if (!context->info_ptr || context->error_condition) { + VLOG(1) << ": DecodePNG <- png_create_info_struct failed"; + CommonFreeDecode(context); + return false; + } + context->data = bit_cast(png_string.data()); + context->data_left = png_string.size(); + png_set_read_fn(context->png_ptr, context, StringReader); + png_read_info(context->png_ptr, context->info_ptr); + png_get_IHDR(context->png_ptr, context->info_ptr, + &context->width, &context->height, + &context->bit_depth, &context->color_type, + 0, 0, 0); + if (context->error_condition) { + VLOG(1) << ": DecodePNG <- error during header parsing."; + CommonFreeDecode(context); + return false; + } + if (context->width <= 0 || context->height <= 0) { + VLOG(1) << ": DecodePNG <- invalid dimensions"; + CommonFreeDecode(context); + return false; + } + if (context->channels == 0) { // Autodetect number of channels + context->channels = context->info_ptr->channels; + } + const bool has_tRNS = (context->info_ptr->valid & PNG_INFO_tRNS) != 0; + const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0; + if ((context->channels & 1) == 0) { // We desire alpha + if (has_alpha) { // There is alpha + } else if (has_tRNS) { + png_set_tRNS_to_alpha(context->png_ptr); // Convert transparency to alpha + } else { + png_set_add_alpha( + context->png_ptr, (1 << context->bit_depth) - 1, PNG_FILLER_AFTER); + } + } else { // We don't want alpha + if (has_alpha || has_tRNS) { // There is alpha + png_set_strip_alpha(context->png_ptr); // Strip alpha + } + } + + // If we only want 8 bits, but are given 16, strip off the LS 8 bits + if (context->bit_depth > 8 && desired_channel_bits <= 8) + png_set_strip_16(context->png_ptr); + + context->need_to_synthesize_16 = + (context->bit_depth <= 8 && desired_channel_bits == 16); + + png_set_packing(context->png_ptr); + context->num_passes = png_set_interlace_handling(context->png_ptr); + png_read_update_info(context->png_ptr, context->info_ptr); + +#ifdef IS_LITTLE_ENDIAN + if (desired_channel_bits > 8) + png_set_swap(context->png_ptr); 
+#endif // IS_LITTLE_ENDIAN + + // convert palette to rgb(a) if needs be. + if (context->color_type == PNG_COLOR_TYPE_PALETTE) + png_set_palette_to_rgb(context->png_ptr); + + // handle grayscale case for source or destination + const bool want_gray = (context->channels < 3); + const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR); + if (is_gray) { // upconvert gray to 8-bit if needed. + if (context->bit_depth < 8) + png_set_gray_1_2_4_to_8(context->png_ptr); + } + if (want_gray) { // output is grayscale + if (!is_gray) + png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587); // 601, JPG + } else { // output is rgb(a) + if (is_gray) + png_set_gray_to_rgb(context->png_ptr); // Enable gray -> RGB conversion + } + return true; +} + +bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) { + CHECK_NOTNULL(data); + + // we need to re-set the jump point so that we trap the errors + // within *this* function (and not CommonInitDecode()) + if (setjmp(png_jmpbuf(context->png_ptr))) { + VLOG(1) << ": DecodePNG error trapped."; + CommonFreeDecode(context); + return false; + } + // png_read_row() takes care of offsetting the pointer based on interlacing + for (int p = 0; p < context->num_passes; ++p) { + png_bytep row = data; + for (int h = context->height; h-- != 0; row += row_bytes) { + png_read_row(context->png_ptr, row, NULL); + } + } + + context->info_ptr->valid |= PNG_INFO_IDAT; + png_read_end(context->png_ptr, context->info_ptr); + + // Clean up. + const bool ok = !context->error_condition; + CommonFreeDecode(context); + + // Synthesize 16 bits from 8 if requested. + if (context->need_to_synthesize_16) + Convert8to16(bit_cast(data), context->channels, row_bytes, + context->width, context->height, bit_cast(data), + row_bytes); + return ok; +} + +bool WriteImageToBuffer( + const void* image, int width, int height, int row_bytes, int num_channels, + int channel_bits, int compression, string* png_string, + const std::vector >* metadata) { + CHECK_NOTNULL(image); + CHECK_NOTNULL(png_string); + // Although this case is checked inside png.cc and issues an error message, + // that error causes memory corruption. + if (width == 0 || height == 0) + return false; + + png_string->resize(0); + png_infop info_ptr = NULL; + png_structp png_ptr = + png_create_write_struct(PNG_LIBPNG_VER_STRING, + NULL, ErrorHandler, WarningHandler); + if (png_ptr == NULL) return false; + if (setjmp(png_jmpbuf(png_ptr))) { + png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : NULL); + return false; + } + info_ptr = png_create_info_struct(png_ptr); + if (info_ptr == NULL) { + png_destroy_write_struct(&png_ptr, NULL); + return false; + } + + int color_type = -1; + switch (num_channels) { + case 1: + color_type = PNG_COLOR_TYPE_GRAY; + break; + case 2: + color_type = PNG_COLOR_TYPE_GRAY_ALPHA; + break; + case 3: + color_type = PNG_COLOR_TYPE_RGB; + break; + case 4: + color_type = PNG_COLOR_TYPE_RGB_ALPHA; + break; + default: + png_destroy_write_struct(&png_ptr, &info_ptr); + return false; + } + + png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush); + if (compression < 0) compression = Z_DEFAULT_COMPRESSION; + png_set_compression_level(png_ptr, compression); + png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL); + // There used to be a call to png_set_filter here turning off filtering + // entirely, but it produced pessimal compression ratios. I'm not sure + // why it was there. 
+ png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type, + PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, + PNG_FILTER_TYPE_DEFAULT); + // If we have metadata write to it. + if (metadata && !metadata->empty()) { + std::vector text; + for (const auto& pair : *metadata) { + png_text txt; + txt.compression = PNG_TEXT_COMPRESSION_NONE; + txt.key = check_metadata_string(pair.first); + txt.text = check_metadata_string(pair.second); + text.push_back(txt); + } + png_set_text(png_ptr, info_ptr, &text[0], text.size()); + } + + png_write_info(png_ptr, info_ptr); +#ifdef IS_LITTLE_ENDIAN + if (channel_bits > 8) + png_set_swap(png_ptr); +#endif // IS_LITTLE_ENDIAN + + png_byte* row = reinterpret_cast(const_cast(image)); + for (; height--; row += row_bytes) png_write_row(png_ptr, row); + png_write_end(png_ptr, NULL); + + png_destroy_write_struct(&png_ptr, &info_ptr); + return true; +} + +} // namespace png +} // namespace tensorflow diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h new file mode 100644 index 0000000000..df9bff7be8 --- /dev/null +++ b/tensorflow/core/lib/png/png_io.h @@ -0,0 +1,88 @@ +// Functions to read and write images in PNG format. +// +// The advantage over image/codec/png{enc,dec}ocder.h is that this library +// supports both 8 and 16 bit images. +// +// The decoding routine accepts binary image data as a StringPiece. These are +// implicitly constructed from strings or char* so they're completely +// transparent to the caller. They're also very cheap to construct so this +// doesn't introduce any additional overhead. +// +// The primary benefit of StringPieces being, in this case, that APIs already +// returning StringPieces (e.g., Bigtable Scanner) or Cords (e.g., IOBuffer; +// only when they're flat, though) or protocol buffer fields typed to either of +// these can be decoded without copying the data into a C++ string. + +#ifndef TENSORFLOW_LIB_PNG_PNG_IO_H_ +#define TENSORFLOW_LIB_PNG_PNG_IO_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "external/png_archive/libpng-1.2.53/png.h" + +namespace tensorflow { +namespace png { + +// Handy container for decoding informations and struct pointers +struct DecodeContext { + const uint8* data; + int data_left; + png_structp png_ptr; + png_infop info_ptr; + png_uint_32 width, height; + int num_passes; + int color_type; + int bit_depth; + int channels; + bool need_to_synthesize_16; + bool error_condition; + DecodeContext() : png_ptr(NULL), info_ptr(NULL) {} +}; + +bool DecodeHeader(StringPiece png_string, int* width, int* height, + int* components, int* channel_bit_depth, + std::vector >* metadata); + +// Sample usage for reading PNG: +// +// string png_string; /* fill with input PNG format data */ +// DecodeContext context; +// CHECK(CommonInitDecode(png_string, 3 /*RGB*/, 8 /*uint8*/, &context)); +// char* image_buffer = new char[3*context.width*context.height]; +// CHECK(CommonFinishDecode(bit_cast(image_buffer), +// 3*context.width /*stride*/, &context)); +// +// desired_channels may be 0 to detected it from the input. + +bool CommonInitDecode(StringPiece png_string, int desired_channels, + int desired_channel_bits, DecodeContext* context); + +bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context); + +// Normally called automatically from CommonFinishDecode. If CommonInitDecode +// is called but not CommonFinishDecode, call this to clean up. Safe to call +// extra times. 
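For illustration, a minimal sketch of the read path documented in the usage comment above: decode an in-memory PNG into an 8-bit RGB buffer via CommonInitDecode() and CommonFinishDecode(). The helper name DecodeRgb8 and the std::vector output buffer are assumptions made for this example, not part of the library.

#include <string>
#include <vector>

#include "tensorflow/core/lib/png/png_io.h"

// Hypothetical helper: decode png_data as 3 channels of 8 bits into *rgb.
bool DecodeRgb8(const std::string& png_data, std::vector<unsigned char>* rgb,
                int* width, int* height) {
  tensorflow::png::DecodeContext ctx;
  if (!tensorflow::png::CommonInitDecode(png_data, 3 /*RGB*/, 8 /*uint8*/,
                                         &ctx)) {
    return false;
  }
  *width = static_cast<int>(ctx.width);
  *height = static_cast<int>(ctx.height);
  const int row_bytes = 3 * (*width);
  rgb->resize(static_cast<size_t>(row_bytes) * (*height));
  // CommonFinishDecode frees the context on both success and failure; only an
  // abandoned CommonInitDecode needs an explicit CommonFreeDecode.
  return tensorflow::png::CommonFinishDecode(
      reinterpret_cast<png_bytep>(rgb->data()), row_bytes, &ctx);
}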
+void CommonFreeDecode(DecodeContext* context); + +// Sample usage for writing PNG: +// +// uint16* image_buffer = new uint16[width*height]; /* fill with pixels */ +// string png_string; +// CHECK(WriteImageToBuffer(image_buffer, width, height, 2*width /*stride*/, +// 1 /*gray*/, 16 /*uint16*/, &png_string, NULL)); +// +// compression is in [-1,9], where 0 is fast and weak compression, 9 is slow +// and strong, and -1 is the zlib default. + +bool WriteImageToBuffer( + const void* image, int width, int height, int row_bytes, int num_channels, + int channel_bits, int compression, string* png_string, + const std::vector >* metadata); + +} // namespace png +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_PNG_PNG_IO_H_ diff --git a/tensorflow/core/lib/png/testdata/lena_gray.png b/tensorflow/core/lib/png/testdata/lena_gray.png new file mode 100644 index 0000000000..8bc73159b0 Binary files /dev/null and b/tensorflow/core/lib/png/testdata/lena_gray.png differ diff --git a/tensorflow/core/lib/png/testdata/lena_rgba.png b/tensorflow/core/lib/png/testdata/lena_rgba.png new file mode 100644 index 0000000000..79f1f84a62 Binary files /dev/null and b/tensorflow/core/lib/png/testdata/lena_rgba.png differ diff --git a/tensorflow/core/lib/random/distribution_sampler.cc b/tensorflow/core/lib/random/distribution_sampler.cc new file mode 100644 index 0000000000..341f1bd595 --- /dev/null +++ b/tensorflow/core/lib/random/distribution_sampler.cc @@ -0,0 +1,80 @@ +#include "tensorflow/core/lib/random/distribution_sampler.h" + +#include +#include + +namespace tensorflow { +namespace random { + +DistributionSampler::DistributionSampler( + const gtl::ArraySlice& weights) { + DCHECK(!weights.empty()); + int n = weights.size(); + num_ = n; + data_.reset(new std::pair[n]); + + std::unique_ptr pr(new double[n]); + + double sum = 0.0; + for (int i = 0; i < n; i++) { + sum += weights[i]; + set_alt(i, -1); + } + + // These are long/short items - called high/low because of reserved keywords. + std::vector high; + high.reserve(n); + std::vector low; + low.reserve(n); + + // compute propotional weights + for (int i = 0; i < n; i++) { + double p = (weights[i] * n) / sum; + pr[i] = p; + if (p < 1.0) { + low.push_back(i); + } else { + high.push_back(i); + } + } + + // Now pair high with low. + while (!high.empty() && !low.empty()) { + int l = low.back(); + low.pop_back(); + int h = high.back(); + high.pop_back(); + + set_alt(l, h); + DCHECK_GE(pr[h], 1.0); + double remaining = pr[h] - (1.0 - pr[l]); + pr[h] = remaining; + + if (remaining < 1.0) { + low.push_back(h); + } else { + high.push_back(h); + } + } + // Transfer pr to prob with rounding errors. + for (int i = 0; i < n; i++) { + set_prob(i, pr[i]); + } + // Because of rounding errors, both high and low may have elements, that are + // close to 1.0 prob. 
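Likewise, a small sketch of the write path described in the comment above. EncodeTinyPng and the hard-coded 2x2 packed-RGB image are assumptions for the example; only WriteImageToBuffer() itself comes from the library, and its bool return value is ignored here for brevity.

#include <string>
#include <vector>

#include "tensorflow/core/lib/png/png_io.h"

// Hypothetical example: encode a 2x2 packed-RGB image into a PNG string.
std::string EncodeTinyPng() {
  const int width = 2;
  const int height = 2;
  const std::vector<unsigned char> rgb = {255, 0,   0,   0,   255, 0,
                                          0,   0,   255, 255, 255, 255};
  std::string png;
  tensorflow::png::WriteImageToBuffer(
      rgb.data(), width, height, /*row_bytes=*/width * 3, /*num_channels=*/3,
      /*channel_bits=*/8, /*compression=*/-1 /* zlib default */, &png,
      /*metadata=*/nullptr);
  return png;
}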
+  for (size_t i = 0; i < high.size(); i++) {
+    int idx = high[i];
+    set_prob(idx, 1.0);
+    // set alt to self to prevent rounding errors from returning 0
+    set_alt(idx, idx);
+  }
+  for (size_t i = 0; i < low.size(); i++) {
+    int idx = low[i];
+    set_prob(idx, 1.0);
+    // set alt to self to prevent rounding errors from returning 0
+    set_alt(idx, idx);
+  }
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h
new file mode 100644
index 0000000000..ab9598a205
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler.h
@@ -0,0 +1,79 @@
+// DistributionSampler allows generating a discrete random variable with a given
+// distribution.
+// The values taken by the variable are [0, N) and relative weights for each
+// value are specified using a vector of size N.
+//
+// The algorithm takes O(N) time to precompute data at construction time and
+// takes O(1) time (2 random number generations, 2 lookups) for each sample.
+// The data structure takes O(N) memory.
+//
+// In contrast, util/random/weighted-picker.h provides O(lg N) sampling.
+// The advantage of that implementation is that weights can be adjusted
+// dynamically, while DistributionSampler doesn't allow weight adjustment.
+//
+// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2.
+
+#ifndef TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#define TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+
+#include
+#include
+#include
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+class DistributionSampler {
+ public:
+  explicit DistributionSampler(const gtl::ArraySlice<float>& weights);
+
+  ~DistributionSampler() {}
+
+  int Sample(SimplePhilox* rand) const {
+    float r = rand->RandFloat();
+    // Since n is typically low, we don't bother with UnbiasedUniform.
+    int idx = rand->Uniform(num_);
+    if (r < prob(idx)) return idx;
+    // else pick alt from that bucket.
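To make the construction above concrete, here is a standalone sketch (hypothetical, mirroring the constructor's pairing loop rather than calling the class) that builds the alias table for weights {0.8, 0.15, 0.05}. The scaled probabilities are {2.4, 0.45, 0.15}; pairing the two low entries with the single high entry yields prob = {1.0, 0.45, 0.15} and alt = {0, 0, 0}, so picking a bucket uniformly and then flipping a biased coin against prob reproduces the original weights exactly.

#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> weights = {0.8, 0.15, 0.05};
  const int n = static_cast<int>(weights.size());
  double sum = 0.0;
  for (double w : weights) sum += w;

  std::vector<double> pr(n), prob(n, 0.0);
  std::vector<int> alt(n, -1), low, high;
  for (int i = 0; i < n; ++i) {
    pr[i] = weights[i] * n / sum;            // {2.4, 0.45, 0.15}
    (pr[i] < 1.0 ? low : high).push_back(i);
  }
  while (!low.empty() && !high.empty()) {
    const int l = low.back();  low.pop_back();
    const int h = high.back(); high.pop_back();
    alt[l] = h;                              // l's leftover mass is served by h
    prob[l] = pr[l];
    pr[h] -= 1.0 - pr[l];                    // h absorbs that leftover mass
    (pr[h] < 1.0 ? low : high).push_back(h);
  }
  // Anything left over is, up to rounding, exactly full.
  for (int i : low)  { prob[i] = 1.0; alt[i] = i; }
  for (int i : high) { prob[i] = 1.0; alt[i] = i; }

  for (int i = 0; i < n; ++i)
    std::printf("bucket %d: prob=%.2f alt=%d\n", i, prob[i], alt[i]);
  return 0;
}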
+ DCHECK_NE(-1, alt(idx)); + return alt(idx); + } + + int num() const { return num_; } + + private: + float prob(int idx) const { + DCHECK_LT(idx, num_); + return data_[idx].first; + } + + int alt(int idx) const { + DCHECK_LT(idx, num_); + return data_[idx].second; + } + + void set_prob(int idx, float f) { + DCHECK_LT(idx, num_); + data_[idx].first = f; + } + + void set_alt(int idx, int val) { + DCHECK_LT(idx, num_); + data_[idx].second = val; + } + + int num_; + std::unique_ptr[]> data_; + + TF_DISALLOW_COPY_AND_ASSIGN(DistributionSampler); +}; + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ diff --git a/tensorflow/core/lib/random/distribution_sampler_test.cc b/tensorflow/core/lib/random/distribution_sampler_test.cc new file mode 100644 index 0000000000..d61a8daa0f --- /dev/null +++ b/tensorflow/core/lib/random/distribution_sampler_test.cc @@ -0,0 +1,90 @@ +#include "tensorflow/core/lib/random/distribution_sampler.h" + +#include +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include + +namespace tensorflow { +namespace random { + +class DistributionSamplerTest : public ::testing::Test { + protected: + // Returns the Chi-Squared statistic for the two distributions. + float TestWeights(const std::vector& weights, int trials_per_bin) { + int iters = weights.size() * trials_per_bin; + std::unique_ptr counts(new float[weights.size()]); + memset(counts.get(), 0, sizeof(float) * weights.size()); + DistributionSampler sampler(weights); + PhiloxRandom philox(testing::RandomSeed(), 17); + SimplePhilox random(&philox); + for (int i = 0; i < iters; i++) { + int r = sampler.Sample(&random); + EXPECT_LT(r, weights.size()); + EXPECT_GE(r, 0); + counts[r] += 1.0; + } + float chi2 = 0.0; + for (size_t i = 0; i < weights.size(); i++) { + counts[i] /= iters; + float err = (counts[i] - weights[i]); + chi2 += (err * err) / weights[i]; + } + return chi2; + } + + void TestDistribution(float* arr, int n) { + std::vector w; + w.reserve(n); + for (int i = 0; i < n; i++) { + w.push_back(arr[i]); + } + float var = TestWeights(w, 1000); + if (var < 0.001) return; + // Maybe a statistical skew. Let's try more iterations. 
+ var = TestWeights(w, 100000); + if (var < 0.001) return; + EXPECT_TRUE(false) << "Chi2 is " << var << " in " << n * 100000 + << "iterations"; + } +}; + +TEST_F(DistributionSamplerTest, KnownDistribution) { + float kEven2[] = {0.5, 0.5}; + float kEven3[] = {0.33333333, 0.33333333, 0.33333333}; + float kEven4[] = {0.25, 0.25, 0.25, 0.25}; + + float kDist1[] = {0.8, 0.15, 0.05}; + + TestDistribution(kEven2, TF_ARRAYSIZE(kEven2)); + TestDistribution(kEven3, TF_ARRAYSIZE(kEven3)); + TestDistribution(kEven4, TF_ARRAYSIZE(kEven4)); + TestDistribution(kDist1, TF_ARRAYSIZE(kDist1)); +} + +static void BM_DistributionSampler(int iters, int n) { + testing::StopTiming(); + PhiloxRandom philox(173, 371); + SimplePhilox rand(&philox); + std::vector weights(n, 0); + for (int i = 0; i < n; i++) { + weights[i] = rand.Uniform(100); + } + DistributionSampler picker(weights); + testing::StartTiming(); + int r = 0; + for (int i = 0; i < iters; i++) { + r |= picker.Sample(&rand); + } + CHECK_NE(r, kint32max); +} + +BENCHMARK(BM_DistributionSampler)->Arg(10)->Arg(100)->Arg(1000); + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/exact_uniform_int.h b/tensorflow/core/lib/random/exact_uniform_int.h new file mode 100644 index 0000000000..616354cc5c --- /dev/null +++ b/tensorflow/core/lib/random/exact_uniform_int.h @@ -0,0 +1,68 @@ +// Exact uniform integers using rejection sampling + +#ifndef TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_ +#define TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_ + +#include + +namespace tensorflow { +namespace random { + +template +UintType ExactUniformInt(const UintType n, const RandomBits& random) { + static_assert(std::is_unsigned::value, + "UintType must be an unsigned int"); + static_assert(std::is_same::value, + "random() should return UintType"); + if (n == 0) { + // Consume a value anyway + // TODO(irving): Assert n != 0, since this case makes no sense. + return random() * n; + } else if (0 == (n & (n - 1))) { + // N is a power of two, so just mask off the lower bits. + return random() & (n - 1); + } else { + // Reject all numbers that skew the distribution towards 0. + + // random's output is uniform in the half-open interval [0, 2^{bits}). + // For any interval [m,n), the number of elements in it is n-m. + + const UintType range = ~static_cast(0); + const UintType rem = (range % n) + 1; + UintType rnd; + + // rem = ((2^bits-1) \bmod n) + 1 + // 1 <= rem <= n + + // NB: rem == n is impossible, since n is not a power of 2 (from + // earlier check). + + do { + rnd = random(); // rnd uniform over [0, 2^{bits}) + } while (rnd < rem); // reject [0, rem) + // rnd is uniform over [rem, 2^{bits}) + // + // The number of elements in the half-open interval is + // + // 2^{bits} - rem = 2^{bits} - ((2^{bits}-1) \bmod n) - 1 + // = 2^{bits}-1 - ((2^{bits}-1) \bmod n) + // = n \cdot \lfloor (2^{bits}-1)/n \rfloor + // + // therefore n evenly divides the number of integers in the + // interval. + // + // The function v \rightarrow v % n takes values from [bias, + // 2^{bits}) to [0, n). Each integer in the range interval [0, n) + // will have exactly \lfloor (2^{bits}-1)/n \rfloor preimages from + // the domain interval. + // + // Therefore, v % n is uniform over [0, n). QED. 
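As a usage sketch of the rejection-sampling helper above, consider rolling a fair six-sided die from a 32-bit source (the RollDie wrapper and the std::mt19937 engine are assumptions for the example). With a plain rng() % 6, 2^32 % 6 == 4, so outcomes 0..3 each receive one extra preimage; ExactUniformInt instead rejects the first rem = (0xFFFFFFFF % 6) + 1 = 4 values, leaving 2^32 - 4 values, which is divisible by 6, so the final modulo is exactly uniform.

#include <cstdint>
#include <random>

#include "tensorflow/core/lib/random/exact_uniform_int.h"

// Hypothetical wrapper: a fair die roll in [1, 6] from a 32-bit source.
int RollDie(std::mt19937* gen) {
  auto bits = [gen]() -> uint32_t { return static_cast<uint32_t>((*gen)()); };
  return 1 + static_cast<int>(
                 tensorflow::random::ExactUniformInt<uint32_t>(6u, bits));
}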
+ + return rnd % n; + } +} + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_ diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h new file mode 100644 index 0000000000..2c3cd0c4b9 --- /dev/null +++ b/tensorflow/core/lib/random/philox_random.h @@ -0,0 +1,232 @@ +// Implement the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf + +#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_ +#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_ + +#include + +#include "tensorflow/core/platform/port.h" + +// Function qualifiers that need to work on both CPU and GPU. +#ifdef __CUDA_ARCH__ +// For nvcc. +#define PHILOX_DEVICE_FUNC __host__ __device__ +#define PHILOX_INLINE __inline__ +#else +// For non-nvcc. +#define PHILOX_DEVICE_FUNC +#define PHILOX_INLINE inline +#endif +#define PHILOX_DEVICE_INLINE PHILOX_DEVICE_FUNC PHILOX_INLINE + +#include + +namespace tensorflow { +namespace random { + +// A class that represents an inline array. It can be used on both CPU and GPU, +// and also trivially copyable between CPU and GPU. +// Arguments: +// T: the array element type; +// ElementCount: the fixed size of the array; +template +class Array { + public: + PHILOX_DEVICE_INLINE Array() { + for (int i = 0; i < ElementCount; ++i) { + data_[i] = T(); + } + } + + PHILOX_DEVICE_INLINE const T& operator[](int index) const { + return data_[index]; + } + + PHILOX_DEVICE_INLINE T& operator[](int index) { return data_[index]; } + + size_t size() const { return ElementCount; } + + private: + T data_[ElementCount]; +}; + +// A class that encapsulates all the states for a random number generator using +// the philox_4x32_10 algorithm. Each invocation returns a 128-bit random bits +// in the form of four uint32. +// There are multiple variants of this algorithm, we picked the 4x32_10 version +// that is most suited for our applications. +// Since this class is meant to be copied between CPU to GPU, it maintains a +// value semantics. +// +// For example: To use this class and populate an array of 1024 randoms on CPU +// with two threads, +// +// void Fill(PhiloxRandom rnd, uint32* output, int start, int limit) { +// assert(start % 4 == 0); +// assert(limit % 4 == 0); +// rnd.Skip(start / 4); +// for (int i = start; i < limit; i += 4) { +// auto sample = rnd(); +// ... copy sample[0..3] to output[i..i+3] +// } +// } +// +// PhiloxRandom rng(seed); +// PhiloxRandom rng_copy = rng; +// rng.Skip(1000/4); +// +// ... schedule Fill(rng_copy, output, 0, 512) in thread 1; +// ... schedule Fill(rng_copy, output, 512, 1024) in thread 2; +// ... wait for thread 1 & 2 to finish executing Fill(). +// +// NOTE: +// 1. PhiloxRandom is trivially copyable. +// 2. PhiloxRandom is compilable by gcc and nvcc. +class PhiloxRandom { + public: + typedef Array ResultType; + typedef uint32 ResultElementType; + // The number of elements that will be returned. 
+ static const int kResultElementCount = 4; + + PHILOX_DEVICE_INLINE + PhiloxRandom() {} + + PHILOX_DEVICE_INLINE + explicit PhiloxRandom(uint64 seed) { + key_[0] = static_cast(seed); + key_[1] = static_cast(seed >> 32); + } + + PHILOX_DEVICE_INLINE + explicit PhiloxRandom(uint64 seed_lo, uint64 seed_hi) { + key_[0] = static_cast(seed_lo); + key_[1] = static_cast(seed_lo >> 32); + counter_[2] = static_cast(seed_hi); + counter_[3] = static_cast(seed_hi >> 32); + } + + // Skip the specified number of samples of 128-bits in the current stream. + PHILOX_DEVICE_INLINE + void Skip(uint64 count) { + const uint32 count_lo = static_cast(count); + uint32 count_hi = static_cast(count >> 32); + + counter_[0] += count_lo; + if (counter_[0] < count_lo) { + ++count_hi; + } + + counter_[1] += count_hi; + if (counter_[1] < count_hi) { + if (++counter_[2] == 0) { + ++counter_[3]; + } + } + } + + // Returns a group of four random numbers using the underlying Philox + // algorithm. + PHILOX_DEVICE_INLINE ResultType operator()() { + ResultType counter = counter_; + Key key = key_; + + // Run the single rounds for ten times. Manually unrolling the loop + // for better performance. + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + RaiseKey(&key); + counter = ComputeSingleRound(counter, key); + + SkipOne(); + + return counter; + } + + private: + // The type for the 64-bit key stored in the form of two 32-bit uint + // that are used in the diffusion process. + typedef Array Key; + + // We use the same constants as recommended by the original paper. + static const uint32 kPhiloxW32A = 0x9E3779B9; + static const uint32 kPhiloxW32B = 0xBB67AE85; + static const uint32 kPhiloxM4x32A = 0xD2511F53; + static const uint32 kPhiloxM4x32B = 0xCD9E8D57; + + // Helper function to skip the next sample of 128-bits in the current stream. + PHILOX_DEVICE_INLINE void SkipOne() { + if (++counter_[0] == 0) { + if (++counter_[1] == 0) { + if (++counter_[2] == 0) { + ++counter_[3]; + } + } + } + } + + // Helper function to return the lower and higher 32-bits from two 32-bit + // integer multiplications. + PHILOX_DEVICE_INLINE + static void MultiplyHighLow(uint32 a, uint32 b, uint32* result_low, + uint32* result_high) { +#ifndef __GCUDACC__ + const uint64 product = static_cast(a) * b; + *result_low = static_cast(product); + *result_high = static_cast(product >> 32); +#else + *result_low = a * b; + *result_high = __umulhi(a, b); +#endif + } + + // Helper function for a single round of the underlying Philox algorithm. 
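To illustrate how the counter-based state above is typically consumed, here is a hedged sketch of filling a buffer block by block; FillUint32 is a hypothetical helper, and n is assumed to be a multiple of 4.

#include "tensorflow/core/lib/random/philox_random.h"

// Hypothetical helper: fill `out` with n 32-bit samples, n a multiple of 4,
// since every PhiloxRandom invocation yields one 128-bit block of four uint32.
void FillUint32(tensorflow::random::PhiloxRandom gen, tensorflow::uint32* out,
                int n) {
  for (int i = 0; i < n; i += 4) {
    const auto block = gen();  // advances the counter by one block
    for (int j = 0; j < 4; ++j) out[i + j] = block[j];
  }
}

Because the state is just a counter plus a key, a copy of the generator with Skip(start / 4) applied produces exactly the samples a sequential pass would have produced from position start, which is what the two-thread Fill() example in the class comment relies on.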
+ PHILOX_DEVICE_INLINE static ResultType ComputeSingleRound( + const ResultType& counter, const Key& key) { + uint32 lo0; + uint32 hi0; + MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0); + + uint32 lo1; + uint32 hi1; + MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1); + + ResultType result; + result[0] = hi1 ^ counter[1] ^ key[0]; + result[1] = lo1; + result[2] = hi0 ^ counter[3] ^ key[1]; + result[3] = lo0; + return result; + } + + PHILOX_DEVICE_INLINE void RaiseKey(Key* key) { + (*key)[0] += kPhiloxW32A; + (*key)[1] += kPhiloxW32B; + } + + private: + ResultType counter_; + Key key_; +}; + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_ diff --git a/tensorflow/core/lib/random/philox_random_test.cc b/tensorflow/core/lib/random/philox_random_test.cc new file mode 100644 index 0000000000..997c0263b7 --- /dev/null +++ b/tensorflow/core/lib/random/philox_random_test.cc @@ -0,0 +1,58 @@ +#include "tensorflow/core/lib/random/philox_random.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/random/philox_random_test_utils.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include + +namespace tensorflow { +namespace random { +namespace { + +// A trivial distribution that just returns the PhiloxRandom as a distribution +class TrivialPhiloxDistribution { + public: + // The number of elements that will be returned. + static constexpr int kResultElementCount = PhiloxRandom::kResultElementCount; + typedef PhiloxRandom::ResultType ResultType; + typedef PhiloxRandom::ResultElementType ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(PhiloxRandom* gen) { return (*gen)(); } +}; + +// This test checks that skipping certain number of samples, is equivalent to +// generate the same number of samples without skipping. +TEST(PhiloxRandomTest, SkipMatchTest) { + constexpr int count = 1024; + constexpr int skip_count = 2048; + + uint64 test_seed = GetTestSeed(); + std::vector v1(count); + { + PhiloxRandom gen(test_seed); + gen.Skip(skip_count / 4); + FillRandoms(gen, &v1[0], v1.size()); + } + + std::vector v2(count + skip_count); + { + PhiloxRandom gen(test_seed); + FillRandoms(gen, &v2[0], v2.size()); + } + + for (int i = 0; i < count; ++i) { + ASSERT_EQ(v1[i], v2[i + skip_count]); + } +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/philox_random_test_utils.h b/tensorflow/core/lib/random/philox_random_test_utils.h new file mode 100644 index 0000000000..d22f6b36e4 --- /dev/null +++ b/tensorflow/core/lib/random/philox_random_test_utils.h @@ -0,0 +1,36 @@ +#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ +#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ + +#include + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace random { + +// Return a random seed. +inline uint64 GetTestSeed() { return New64(); } + +// A utility function to fill the given array with samples from the given +// distribution. 
+template +void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p, + int64 size) { + const int granularity = Distribution::kResultElementCount; + + CHECK(size % granularity == 0) << " size: " << size + << " granularity: " << granularity; + + Distribution dist; + for (int i = 0; i < size; i += granularity) { + const auto sample = dist(&gen); + std::copy(&sample[0], &sample[0] + granularity, &p[i]); + } +} + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ diff --git a/tensorflow/core/lib/random/random.cc b/tensorflow/core/lib/random/random.cc new file mode 100644 index 0000000000..2959b05382 --- /dev/null +++ b/tensorflow/core/lib/random/random.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/lib/random/random.h" + +#include +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace random { + +std::mt19937_64* InitRng() { + std::random_device device("/dev/random"); + return new std::mt19937_64(device()); +} + +uint64 New64() { + static std::mt19937_64* rng = InitRng(); + static mutex mu; + mutex_lock l(mu); + return (*rng)(); +} + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/random.h b/tensorflow/core/lib/random/random.h new file mode 100644 index 0000000000..1a20436c4e --- /dev/null +++ b/tensorflow/core/lib/random/random.h @@ -0,0 +1,16 @@ +#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_ +#define TENSORFLOW_LIB_RANDOM_RANDOM_H_ + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace random { + +// Return a 64-bit random value. Different sequences are generated +// in different processes. +uint64 New64(); + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_RANDOM_H_ diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h new file mode 100644 index 0000000000..caafcde513 --- /dev/null +++ b/tensorflow/core/lib/random/random_distributions.h @@ -0,0 +1,361 @@ +#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/random/philox_random.h" + +namespace tensorflow { +namespace random { + +// Helper function to convert a 32-bit integer to a float between [0..1). +PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x); +// Helper function to convert two 32-bit integers to a double between [0..1). +PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1); + +// A class that generates uniform distribution random numbers from the +// underlying random integer generator. +// Arguments: +// Generator: a generator type that returns a number of uint32 upon each +// each invocation. It needs to define kResultElementCount for the +// sample count for each invocation, and ResultType for actual +// returned sample type. +// RealType: the data type of the real numberes that will be returned by the +// distribution. This could be either float or double for now. +// This class is meant to be implemented through specialization. The default +// is not defined by design. +template +class UniformDistribution; + +template +class UniformDistribution { + public: + // The number of elements that will be returned. + static const int kResultElementCount = Generator::kResultElementCount; + // Indicate that this distribution may take variable number of samples + // during the runtime. 
+ static const bool kVariableSamplesPerOutput = false; + typedef Array ResultType; + typedef float ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(Generator* gen) { + typename Generator::ResultType sample = (*gen)(); + ResultType result; + for (int i = 0; i < kResultElementCount; ++i) { + result[i] = Uint32ToFloat(sample[i]); + } + return result; + } +}; + +template +class UniformDistribution { + public: + // The number of elements that will be returned. + static const int kResultElementCount = Generator::kResultElementCount / 2; + // Indicate that this distribution may take variable number of samples + // during the runtime. + static const bool kVariableSamplesPerOutput = false; + typedef Array ResultType; + typedef double ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(Generator* gen) { + typename Generator::ResultType sample = (*gen)(); + ResultType result; + for (int i = 0; i < kResultElementCount; ++i) { + result[i] = Uint64ToDouble(sample[2 * i], sample[2 * i + 1]); + } + return result; + } +}; + +// A class that adapts the underlying native multiple samples to return a single +// sample at a time. +template +class SingleSampleAdapter { + public: + // The number of elements that will be returned. + static const int kResultElementCount = 1; + // The number of elements that will be returned by the underlying generator. + static const int kNativeElementCount = Generator::kResultElementCount; + typedef typename Generator::ResultElementType ResultType; + typedef typename Generator::ResultElementType ResultElementType; + + PHILOX_DEVICE_INLINE + explicit SingleSampleAdapter(Generator* gen) + : generator_(gen), used_result_index_(Generator::kResultElementCount) {} + + PHILOX_DEVICE_INLINE + ResultType operator()() { + if (used_result_index_ == Generator::kResultElementCount) { + unused_results_ = (*generator_)(); + used_result_index_ = 0; + } + + return unused_results_[used_result_index_++]; + } + + private: + Generator* generator_; + typename Generator::ResultType unused_results_; + int used_result_index_; +}; + +// A class that generates unit normal distribution random numbers from the +// underlying random integer generator. +// Arguments: +// Generator: a generator type that returns a number of uint32 upon each +// each invocation. It needs to define kResultElementCount for the +// sample count for each invocation, and ResultType for actual +// returned sample type. +// RealType: the data type of the real numberes that will be returned by the +// distribution. This could be either float or double for now. +// This class is meant to be implemented through specialization. The default +// is not defined by design. +template +class NormalDistribution; + +PHILOX_DEVICE_INLINE +void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1); + +PHILOX_DEVICE_INLINE +void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0, + double* d1); + +template +class NormalDistribution { + public: + // The number of elements that will be returned. + static const int kResultElementCount = Generator::kResultElementCount; + // Indicate that this distribution may take variable number of samples + // during the runtime. 
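A brief sketch of driving the uniform specialization above end to end; the UniformFloats helper is hypothetical, and the number of elements produced per call (four for PhiloxRandom) comes from kResultElementCount.

#include <vector>

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

// Hypothetical helper: n uniform floats in [0, 1).
std::vector<float> UniformFloats(tensorflow::uint64 seed, int n) {
  tensorflow::random::PhiloxRandom gen(seed);
  tensorflow::random::UniformDistribution<tensorflow::random::PhiloxRandom,
                                          float> dist;
  std::vector<float> out;
  out.reserve(n);
  while (static_cast<int>(out.size()) < n) {
    const auto block = dist(&gen);  // one 128-bit draw -> four floats
    for (int i = 0; i < dist.kResultElementCount; ++i) {
      if (static_cast<int>(out.size()) < n) out.push_back(block[i]);
    }
  }
  return out;
}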
+ static const bool kVariableSamplesPerOutput = false; + typedef Array ResultType; + typedef float ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(Generator* gen) { + typename Generator::ResultType sample = (*gen)(); + ResultType result; + for (int i = 0; i < kResultElementCount; i += 2) { + BoxMullerFloat(sample[i], sample[i + 1], &result[i], &result[i + 1]); + } + return result; + } +}; + +template +class NormalDistribution { + public: + // The number of elements that will be returned. + static const int kResultElementCount = Generator::kResultElementCount / 2; + // Indicate that this distribution may take variable number of samples + // during the runtime. + static const bool kVariableSamplesPerOutput = false; + typedef Array ResultType; + typedef double ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(Generator* gen) { + typename Generator::ResultType sample = (*gen)(); + ResultType result; + for (int i = 0; i < kResultElementCount; i += 2) { + const int i2 = 2 * i; + BoxMullerDouble(sample[i2], sample[i2 + 1], sample[i2 + 2], + sample[i2 + 3], &result[i], &result[i + 1]); + } + return result; + } +}; + +// A class that returns standard normal distribution between +// [-kTruncateValue, kTruncateValue]. +// Arguments: +// Generator: a generator type that returns a number of uint32 upon each +// each invocation. It needs to define kResultElementCount for the +// sample count for each invocation, and ResultType for actual +// returned sample type. +// RealType: the data type of the real numberes that will be returned by the +// distribution. This could be either float or double for now. +// This class is meant to be implemented through specialization. The default +// is not defined by design. +template +class TruncatedNormalDistribution; + +// Partial specialization for float. +template +class TruncatedNormalDistribution { + public: + // The number of elements that will be returned. + static const int kResultElementCount = + SingleSampleGenerator::kNativeElementCount; + // Indicate that this distribution may take variable number of samples + // during the runtime. + static const bool kVariableSamplesPerOutput = true; + // The threshold where the normal distribution is truncated. + const float kTruncateValue = 2.0f; + + typedef Array ResultType; + typedef float ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(SingleSampleGenerator* gen) { + ResultType results; + int index = 0; + while (true) { + // Repeatedly take samples from the normal distribution, until we have + // the desired number of elements that fall within the pre-defined cutoff + // threshold. + const uint32 x0 = (*gen)(); + const uint32 x1 = (*gen)(); + float f[2]; + BoxMullerFloat(x0, x1, &f[0], &f[1]); + + for (int i = 0; i < 2; ++i) { + if (fabs(f[i]) < kTruncateValue) { + results[index++] = f[i]; + if (index >= kResultElementCount) { + return results; + } + } + } + } + } +}; + +// Partial specialization for double. +template +class TruncatedNormalDistribution { + public: + // The number of elements that will be returned. + static const int kResultElementCount = + (SingleSampleGenerator::kNativeElementCount > 1) + ? SingleSampleGenerator::kNativeElementCount / 2 + : 1; + // Indicate that this distribution may take variable number of samples + // during the runtime. 
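Because the truncated specialization above rejects draws outside the cutoff, it consumes a variable number of underlying samples and is therefore driven through SingleSampleAdapter rather than by the block generator directly. A minimal sketch, with the OneTruncatedNormal helper name being an assumption:

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

// Hypothetical helper: one float from the truncated normal on [-2, 2].
float OneTruncatedNormal(tensorflow::random::PhiloxRandom* gen) {
  using tensorflow::random::PhiloxRandom;
  using tensorflow::random::SingleSampleAdapter;
  using tensorflow::random::TruncatedNormalDistribution;
  SingleSampleAdapter<PhiloxRandom> single(gen);
  TruncatedNormalDistribution<SingleSampleAdapter<PhiloxRandom>, float> dist;
  return dist(&single)[0];  // the call actually yields a block of four floats
}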
+ static const bool kVariableSamplesPerOutput = true; + typedef Array ResultType; + typedef double ResultElementType; + const double kTruncateValue = 2.0; + + PHILOX_DEVICE_INLINE + ResultType operator()(SingleSampleGenerator* gen) { + ResultType results; + int index = 0; + while (1) { + const uint32 x0 = (*gen)(); + const uint32 x1 = (*gen)(); + const uint32 x2 = (*gen)(); + const uint32 x3 = (*gen)(); + double d[2]; + BoxMullerDouble(x0, x1, x2, x3, &d[0], &d[1]); + + for (int i = 0; i < 2; ++i) { + if (fabs(d[i]) < kTruncateValue) { + results[index++] = d[i]; + if (index >= kResultElementCount) { + return results; + } + } + } + } + } +}; + +// Helper function to convert two 32-bit uniform integers to two floats +// under the unit normal distribution. +PHILOX_DEVICE_INLINE +void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1) { + // This function implements the Box-Muller transform: + // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form + // Do not send a really small number to log(). + // We cannot mark "epsilon" as "static const" because NVCC would complain + const float epsilon = 1.0e-7f; + float u1 = Uint32ToFloat(x0); + if (u1 < epsilon) { + u1 = epsilon; + } + const float v1 = 2.0f * M_PI * Uint32ToFloat(x1); + const float u2 = sqrt(-2.0f * log(u1)); +#if defined(__linux) + sincosf(v1, f0, f1); +#else + *f0 = sinf(v1); + *f1 = cosf(v1); +#endif + *f0 *= u2; + *f1 *= u2; +} + +// Helper function to convert four 32-bit uniform integers to two doubles +// under the unit normal distribution. +PHILOX_DEVICE_INLINE +void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0, + double* d1) { + // This function implements the Box-Muller transform: + // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form + // Do not send a really small number to log(). + // We cannot mark "epsilon" as "static const" because NVCC would complain + const double epsilon = 1.0e-7; + double u1 = Uint64ToDouble(x0, x1); + if (u1 < epsilon) { + u1 = epsilon; + } + const double v1 = 2 * M_PI * Uint64ToDouble(x2, x3); + const double u2 = sqrt(-2.0 * log(u1)); +#if defined(__linux) + sincos(v1, d0, d1); +#else + *d0 = sin(v1); + *d1 = cos(v1); +#endif + *d0 *= u2; + *d1 *= u2; +} + +// Helper function to convert an 32-bit integer to a float between [0..1). +PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) { + // IEEE754 floats are formatted as follows (MSB first): + // sign(1) exponent(8) mantissa(23) + // Conceptually construct the following: + // sign == 0 + // exponent == 127 -- an excess 127 representation of a zero exponent + // mantissa == 23 random bits + const uint32 man = x & 0x7fffffu; // 23 bit mantissa + const uint32 exp = static_cast(127); + const uint32 val = (exp << 23) | man; + + // Assumes that endian-ness is same for float and uint32. + float result; + memcpy(&result, &val, sizeof(val)); + return result - 1.0f; +} + +// Helper function to convert two 32-bit integers to a double between [0..1). 
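A quick numeric check of the bit trick above, as a standalone sketch with a hypothetical Uint32ToFloatSketch mirror of the helper: OR-ing 23 random mantissa bits under a fixed exponent of 127 produces a float in [1.0, 2.0), and subtracting 1.0f maps it onto [0, 1).

#include <cstdint>
#include <cstdio>
#include <cstring>

float Uint32ToFloatSketch(uint32_t x) {
  const uint32_t val = (127u << 23) | (x & 0x7fffffu);  // exponent 127, 23 random bits
  float f;
  std::memcpy(&f, &val, sizeof(val));  // reinterpret the bit pattern, in [1.0, 2.0)
  return f - 1.0f;                     // shift down to [0.0, 1.0)
}

int main() {
  std::printf("%.9g\n", Uint32ToFloatSketch(0u));         // 0
  std::printf("%.9g\n", Uint32ToFloatSketch(0x400000u));  // 0.5
  std::printf("%.9g\n", Uint32ToFloatSketch(0x7fffffu));  // 0.999999881, the maximum
  return 0;
}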
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1) { + // IEEE754 doubles are formatted as follows (MSB first): + // sign(1) exponent(11) mantissa(52) + // Conceptually construct the following: + // sign == 0 + // exponent == 1023 -- an excess 1023 representation of a zero exponent + // mantissa == 52 random bits + const uint32 mhi = x0 & 0xfffffu; // upper 20 bits of mantissa + const uint32 mlo = x1; // lower 32 bits of mantissa + const uint64 man = (static_cast(mhi) << 32) | mlo; // mantissa + const uint64 exp = static_cast(1023); + const uint64 val = (exp << 52) | man; + // Assumes that endian-ness is same for double and uint64. + double result; + memcpy(&result, &val, sizeof(val)); + return result - 1.0; +} + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc new file mode 100644 index 0000000000..3ce86a907a --- /dev/null +++ b/tensorflow/core/lib/random/random_distributions_test.cc @@ -0,0 +1,270 @@ +#include "tensorflow/core/lib/random/random_distributions.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/philox_random_test_utils.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/logging.h" +#include + +namespace tensorflow { +namespace random { +namespace { + +// The largest z-value we want to tolerate. Since the z-test approximates a +// unit normal distribution, it should almost definitely never exceed 6. +static constexpr float kZLimit = 6.0; + +// A utility function to fill the given array with samples from the given +// distribution, using the single adatper of the underlying generator +template +void FillRandomsWithSingles(PhiloxRandom gen, + typename Distribution::ResultElementType* p, + int64 size) { + int granularity = Distribution::kResultElementCount; + + CHECK(size % granularity == 0) << " size: " << size + << " granularity: " << granularity; + + SingleSampleAdapter single_samples(&gen); + + Distribution dist; + for (int i = 0; i < size; i += granularity) { + auto sample = dist(&single_samples); + std::copy(&sample[0], &sample[0] + granularity, &p[i]); + } +} + +// Check the given array of samples matches the given theoretical moment +// function at different orders. The test is considered passing if the z-tests +// of all statistical moments are all below z_limit. +// typename T in the template argument could be either float or double. 
+// Arguments: +// samples: an array of samples to be tested for their statistical properties; +// theoretical_moments: a functor that can calculate arbitrary order of +// of the given distribution; +// max_moments: the largest moments of the uniform distribution to be tested; +// stride: the distance between samples to check for statistical properties +// 0 means the n-th moment of each sample +// any other strides tests for spatial correlation between samples; +// z_limit: the maximum z-test we would consider the test to pass; +template +bool CheckSamplesMoments(const std::vector& samples, + std::function theoretical_moments, + int max_moments, int stride, T z_limit) { + const T* const samples_data = &samples[0]; + const int samples_size = samples.size(); + std::vector moments(max_moments + 1); + double* const moments_data = &moments[0]; + std::vector moments_sample_count(max_moments + 1); + int* const moments_sample_count_data = &moments_sample_count[0]; + + for (int k = 0; k < samples_size; ++k) { + double moment = 1.; + for (int i = 0; i <= max_moments; ++i) { + int index = k + i * stride; + if (index >= samples_size) { + break; + } + // moments[i] store the i-th order measured moments. + // bypass std::vector::opeartor[] because they are too slow in the debug + // mode, given the large number of samples. + moments_data[i] += moment; + ++moments_sample_count_data[i]; + moment *= samples_data[index]; + } + } + + // normalize the moments + for (int i = 0; i <= max_moments; ++i) { + moments[i] /= moments_sample_count[i]; + } + + bool status = true; + + for (int i = 1; i <= max_moments; ++i) { + // Calculate the theoretical mean and variance + const double moments_i_mean = (stride == 0) + ? theoretical_moments(i) + : std::pow(theoretical_moments(1), i); + const double moments_i_squared = (stride == 0) + ? theoretical_moments(2 * i) + : std::pow(theoretical_moments(2), i); + const double moments_i_var = + moments_i_squared - moments_i_mean * moments_i_mean; + + // assume every operation has a small numerical error. + static const double kNumericalError = 1e-6; + // it takes i multiplications to calculate one i-th moment. + const double error_per_moment = i * kNumericalError; + const double total_variance = + moments_i_var / moments_sample_count[i] + error_per_moment; + // z_test is approximately a unit normal distribution. + const double z_test = + fabs((moments[i] - moments_i_mean) / sqrt(total_variance)); + + if (z_test > z_limit) { + LOG(ERROR) << "failing z_test:" + << " moment: " << i << " stride: " << stride + << " z_test: " << z_test << " z_limit: " << z_limit + << " measured moments: " << moments[i] + << " theoretical mean of the moments: " << moments_i_mean + << " theoretical var of the moments: " << moments_i_var + << " sample count: " << moments_sample_count[i]; + status = false; + } + } + + return status; +} + +// This tests checks that the generated samples match the theoretical moments +// of the uniform distribution. +template +void UniformMomentsTest(int count, int max_moments, + const std::vector& strides, T z_limit) { + auto uniform_moments = [](int n) -> double { return 1. / (n + 1); }; + + std::vector v1(count); + uint64 seed = GetTestSeed(); + PhiloxRandom gen(seed); + FillRandoms >(gen, &v1[0], v1.size()); + for (int stride : strides) { + bool status = CheckSamplesMoments(v1, uniform_moments, max_moments, + stride, z_limit); + ASSERT_TRUE(status) << " UniformMomentsTest failing. 
seed: " << seed; + } +} + +// This test checks that the generated samples match the theoretical moments +// of the unit normal distribution. +template +void NormalMomentsTest(int count, int max_moments, + const std::vector& strides, T z_limit) { + auto normal_moments = [](int n) -> double { + if (n % 2 == 1) { + // For an odd order, the moment of a unit normal distribution is zero. + return 0.; + } else { + // For an even order, the moment of a unit normal distribution is. + // (n-1)!! + double v = 1.; + for (int i = n - 1; i >= 1; i -= 2) { + v *= i; + } + return v; + } + }; + + std::vector v1(count); + uint64 seed = GetTestSeed(); + PhiloxRandom gen(seed); + FillRandoms >(gen, &v1[0], v1.size()); + + for (int stride : strides) { + bool status = CheckSamplesMoments(v1, normal_moments, max_moments, + stride, z_limit); + ASSERT_TRUE(status) << " NormalMomentsTest failing. seed: " << seed; + } +} + +// A functor to calculate the moments for the truncated normal distribution. +// For any odd order, the moment is zero. But for any other n, it can be proven +// that the following recursive relationship for the moments of the truncated +// standard normal: +// m(n) = (n - 1) * m(n - 2) - 2 * v ^ (n - 1) * f(v) / (2 * Phi(v) - 1) +// where v is the cut-off value, f(v) is the p.d.f of the standard +// normal, and Phi(v) is the c.d.f of the standard normal. +class TruncatedNormalMoments { + public: + double operator()(int n) { + if (n == 0) { + return 1; + } + if (n % 2 == 1) { + // For an odd order, the moment is always zero + return 0.; + } + + // Memoization and check the cached results. + auto iter = cached_results_.find(n); + if (iter != cached_results_.end()) { + return iter->second; + } + + // The real computation of the moment. + double bias = 2.0 * std::pow(kV, n - 1) * kFV / (2.0 * kPhiV - 1.0); + double moment_n_minus_2 = (*this)(n - 2); + double moment_n = (n - 1) * moment_n_minus_2 - bias; + + cached_results_[n] = moment_n; + return moment_n; + } + + private: + const double kV = 2.0; + // f(v), where f is the p.d.f of the normal distribution and v=2. + const double kFV = 1.0 / sqrt(2.0 * M_PI) * exp(-kV * kV / 2.0); + // The numerical evaluation of Phi(v), where v is the truncate value. + // v = 2 in the current implementation. + const double kPhiV = 0.977249868051821; + std::unordered_map cached_results_; +}; + +// This test checks that the generated samples matche the theoretical moments +// of the truncated normal distribution. +template +void RandomParametersMomentsTest(int count, int max_moments, + const std::vector& strides, T z_limit) { + std::vector v1(count); + uint64 seed = GetTestSeed(); + PhiloxRandom gen(seed); + FillRandomsWithSingles< + TruncatedNormalDistribution, T> >( + gen, &v1[0], v1.size()); + + for (int stride : strides) { + bool status = CheckSamplesMoments(v1, TruncatedNormalMoments(), + max_moments, stride, z_limit); + ASSERT_TRUE(status) << " NormalMomentsTest failing. 
seed: " << seed; + } +} + +TEST(PhiloxRandomTest, UniformFloatMomentsTest) { + const std::vector strides = {0, 1, 4, 17}; + UniformMomentsTest(1 << 20, 40, strides, kZLimit); +} + +TEST(PhiloxRandomTest, NormalFloatMomentsTest) { + const std::vector strides = {0, 1, 4, 17}; + NormalMomentsTest(8 << 20, 25, strides, kZLimit); +} + +TEST(PhiloxRandomTest, RandomParametersFloatMomentsTest) { + const std::vector strides = {0, 1, 4, 17}; + RandomParametersMomentsTest(1 << 20, 40, strides, kZLimit); +} + +TEST(PhiloxRandomTest, UniformDoubleMomentsTest) { + const std::vector strides = {0, 1, 4, 17}; + UniformMomentsTest(1 << 20, 40, strides, kZLimit); +} + +TEST(PhiloxRandomTest, NormalDoubleMomentsTest) { + const std::vector strides = {0, 1, 4, 17}; + NormalMomentsTest(8 << 20, 25, strides, kZLimit); +} + +TEST(PhiloxRandomTest, RandomParametersDoubleMomentsTest) { + const std::vector strides = {0, 1, 4, 17}; + RandomParametersMomentsTest(1 << 20, 40, strides, kZLimit); +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/random_test.cc b/tensorflow/core/lib/random/random_test.cc new file mode 100644 index 0000000000..7ed37c8b5e --- /dev/null +++ b/tensorflow/core/lib/random/random_test.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/lib/random/random.h" + +#include +#include "tensorflow/core/platform/port.h" +#include + +namespace tensorflow { +namespace random { +namespace { + +TEST(New64Test, SanityCheck) { + std::set values; + for (int i = 0; i < 1000000; i++) { + uint64 x = New64(); + EXPECT_TRUE(values.insert(x).second) << "duplicate " << x; + } +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/simple_philox.cc b/tensorflow/core/lib/random/simple_philox.cc new file mode 100644 index 0000000000..1035e1f017 --- /dev/null +++ b/tensorflow/core/lib/random/simple_philox.cc @@ -0,0 +1,24 @@ +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/random/exact_uniform_int.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace random { + +uint32 SimplePhilox::Uniform(uint32 n) { + return ExactUniformInt(n, [this]() { return Rand32(); }); +} + +uint64 SimplePhilox::Uniform64(uint64 n) { + return ExactUniformInt(n, [this]() { return Rand64(); }); +} + +uint32 SimplePhilox::Skewed(int max_log) { + CHECK(0 <= max_log && max_log <= 32); + const int shift = Rand32() % (max_log + 1); + const uint32 mask = shift == 32 ? 
~static_cast(0) : (1 << shift) - 1; + return Rand32() & mask; +} + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h new file mode 100644 index 0000000000..12b15d7616 --- /dev/null +++ b/tensorflow/core/lib/random/simple_philox.h @@ -0,0 +1,61 @@ +#ifndef TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_ +#define TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { +namespace random { + +// A simple imperative interface to Philox +class SimplePhilox { + public: + PHILOX_DEVICE_INLINE + explicit SimplePhilox(PhiloxRandom* gen) : single_(gen) {} + + // 32 random bits + PHILOX_DEVICE_INLINE uint32 Rand32() { return single_(); } + + // 64 random bits + PHILOX_DEVICE_INLINE uint64 Rand64() { + const uint32 lo = single_(), hi = single_(); + return lo | static_cast(hi) << 32; + } + + // Uniform float in [0, 1) + PHILOX_DEVICE_INLINE float RandFloat() { return Uint32ToFloat(single_()); } + + // Uniform double in [0, 1) + PHILOX_DEVICE_INLINE double RandDouble() { + const uint32 x0 = single_(), x1 = single_(); + return Uint64ToDouble(x0, x1); + } + + // Uniform integer in [0, n). + // Uses rejection sampling, so may need more than one 32-bit sample. + uint32 Uniform(uint32 n); + + // Approximately uniform integer in [0, n). + // Uses rejection sampling, so may need more than one 64-bit sample. + uint64 Uniform64(uint64 n); + + // True with probability 1/n. + bool OneIn(uint32 n) { return Uniform(n) == 0; } + + // Skewed: pick "base" uniformly from range [0,max_log] and then + // return "base" random bits. The effect is to pick a number in the + // range [0,2^max_log-1] with bias towards smaller numbers. 
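+  // A usage sketch, for illustration only (the seed values are arbitrary):
+  //
+  //   PhiloxRandom philox(7, 7);
+  //   SimplePhilox rnd(&philox);
+  //   uint32 v = rnd.Skewed(20);  // v is in [0, 2^20 - 1]; each bit-width
+  //                               // 0..20 is chosen with probability 1/21,
+  //                               // so small values are far more likely.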
+ uint32 Skewed(int max_log); + + private: + SingleSampleAdapter single_; +}; + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_ diff --git a/tensorflow/core/lib/random/simple_philox_test.cc b/tensorflow/core/lib/random/simple_philox_test.cc new file mode 100644 index 0000000000..4246b8b4dd --- /dev/null +++ b/tensorflow/core/lib/random/simple_philox_test.cc @@ -0,0 +1,120 @@ +#include "tensorflow/core/lib/random/simple_philox.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include + +namespace tensorflow { +namespace random { +namespace { + +TEST(SimplePhiloxTest, FloatTest) { + PhiloxRandom philox(7, 7); + SimplePhilox gen(&philox); + static const int kIters = 1000000; + for (int i = 0; i < kIters; ++i) { + float f = gen.RandFloat(); + EXPECT_LE(0.0f, f); + EXPECT_GT(1.0f, f); + } + for (int i = 0; i < kIters; ++i) { + double d = gen.RandDouble(); + EXPECT_LE(0.0, d); + EXPECT_GT(1.0, d); + } +} + +static void DifferenceTest(const char *names, SimplePhilox *gen1, + SimplePhilox *gen2) { + static const int kIters = 100; + bool different = false; + for (int i = 0; i < kIters; ++i) { + if (gen1->Rand32() != gen2->Rand32()) { + different = true; + break; + } + } + CHECK(different) << "different seeds but same output!"; +} + +TEST(SimplePhiloxTest, DifferenceTest) { + PhiloxRandom philox1(1, 1), philox2(17, 17); + SimplePhilox gen1(&philox1), gen2(&philox2); + + DifferenceTest("SimplePhilox: different seeds", &gen1, &gen2); +} + +TEST(SimplePhiloxTest, DifferenceTestCloseSeeds) { + PhiloxRandom philox1(1, 1), philox2(2, 1); + SimplePhilox gen1(&philox1), gen2(&philox2); + + DifferenceTest("SimplePhilox: close seeds", &gen1, &gen2); +} + +TEST(SimplePhiloxTest, Regression_CloseSeedsAreDifferent) { + const int kCount = 1000; + + // Two seeds differ only by the last bit. + PhiloxRandom philox1(0, 1), philox2(1, 1); + SimplePhilox gen1(&philox1), gen2(&philox2); + + std::set first; + std::set all; + for (int i = 0; i < kCount; ++i) { + uint32 v = gen1.Rand32(); + first.insert(v); + all.insert(v); + all.insert(gen2.Rand32()); + } + + // Broken array initialization implementation (before 2009-08-18) using the + // above seeds return <1000, 1007>, generating output that is >99% similar. + // The fix returns <1000, 2000> for completely disjoint sets. 
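+  // (In other words: all kCount draws from gen1 are distinct, and none of
+  // the kCount draws from gen2 collides with them -- which is what the two
+  // set-size checks below assert.)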
+ EXPECT_EQ(kCount, first.size()); + EXPECT_EQ(2 * kCount, all.size()); +} + +TEST(SimplePhiloxTest, TestUniform) { + PhiloxRandom philox(17, 17); + SimplePhilox gen(&philox); + + uint32 range = 3 * (1L << 29); + uint32 threshold = 1L << 30; + + size_t count = 0; + static const int kTrials = 100000; + for (int i = 0; i < kTrials; ++i) { + uint32 rnd = gen.Uniform(range); + if (rnd < threshold) { + ++count; + } + } + + EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005); +} + +TEST(SimplePhiloxTest, TestUniform64) { + PhiloxRandom philox(17, 17); + SimplePhilox gen(&philox); + + uint64 range = 3 * (1LL << 59); + uint64 threshold = 1LL << 60; + + size_t count = 0; + static const int kTrials = 100000; + for (int i = 0; i < kTrials; ++i) { + uint64 rnd = gen.Uniform64(range); + if (rnd < threshold) { + ++count; + } + } + + EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005); +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/weighted_picker.cc b/tensorflow/core/lib/random/weighted_picker.cc new file mode 100644 index 0000000000..f96da578ec --- /dev/null +++ b/tensorflow/core/lib/random/weighted_picker.cc @@ -0,0 +1,203 @@ +#include "tensorflow/core/lib/random/weighted_picker.h" + +#include +#include + +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { +namespace random { + +WeightedPicker::WeightedPicker(int N) { + CHECK_GE(N, 0); + N_ = N; + + // Find the number of levels + num_levels_ = 1; + while (LevelSize(num_levels_ - 1) < N) { + num_levels_++; + } + + // Initialize the levels + level_ = new int32*[num_levels_]; + for (int l = 0; l < num_levels_; l++) { + level_[l] = new int32[LevelSize(l)]; + } + + SetAllWeights(1); +} + +WeightedPicker::~WeightedPicker() { + for (int l = 0; l < num_levels_; l++) { + delete[] level_[l]; + } + delete[] level_; +} + +static int32 UnbiasedUniform(SimplePhilox* r, int32 n) { + CHECK_LE(0, n); + const uint32 range = ~static_cast(0); + if (n == 0) { + return r->Rand32() * n; + } else if (0 == (n & (n - 1))) { + // N is a power of two, so just mask off the lower bits. + return r->Rand32() & (n - 1); + } else { + // Reject all numbers that skew the distribution towards 0. + + // Rand32's output is uniform in the half-open interval [0, 2^{32}). + // For any interval [m,n), the number of elements in it is n-m. + + uint32 rem = (range % n) + 1; + uint32 rnd; + + // rem = ((2^{32}-1) \bmod n) + 1 + // 1 <= rem <= n + + // NB: rem == n is impossible, since n is not a power of 2 (from + // earlier check). + + do { + rnd = r->Rand32(); // rnd uniform over [0, 2^{32}) + } while (rnd < rem); // reject [0, rem) + // rnd is uniform over [rem, 2^{32}) + // + // The number of elements in the half-open interval is + // + // 2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1 + // = 2^{32}-1 - ((2^{32}-1) \bmod n) + // = n \cdot \lfloor (2^{32}-1)/n \rfloor + // + // therefore n evenly divides the number of integers in the + // interval. + // + // The function v \rightarrow v % n takes values from [bias, + // 2^{32}) to [0, n). Each integer in the range interval [0, n) + // will have exactly \lfloor (2^{32}-1)/n \rfloor preimages from + // the domain interval. + // + // Therefore, v % n is uniform over [0, n). QED. 
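+    // (A worked instance, for illustration: with n == 6,
+    //  rem = ((2^32 - 1) % 6) + 1 = 3 + 1 = 4, so draws in [0, 4) are
+    //  rejected.  The surviving interval [4, 2^32) holds
+    //  2^32 - 4 = 6 * 715827882 values, so every residue class mod 6
+    //  appears exactly 715827882 times and "rnd % n" below is exactly
+    //  uniform over [0, 6).)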
+ + return rnd % n; + } +} + +int WeightedPicker::Pick(SimplePhilox* rnd) const { + if (total_weight() == 0) return -1; + + // using unbiased uniform distribution to avoid bias + // toward low elements resulting from a possible use + // of big weights. + return PickAt(UnbiasedUniform(rnd, total_weight())); +} + +int WeightedPicker::PickAt(int32 weight_index) const { + if (weight_index < 0 || weight_index >= total_weight()) return -1; + + int32 position = weight_index; + int index = 0; + + for (int l = 1; l < num_levels_; l++) { + // Pick left or right child of "level_[l-1][index]" + const int32 left_weight = level_[l][2 * index]; + if (position < left_weight) { + // Descend to left child + index = 2 * index; + } else { + // Descend to right child + index = 2 * index + 1; + position -= left_weight; + } + } + CHECK_GE(index, 0); + CHECK_LT(index, N_); + CHECK_LE(position, level_[num_levels_ - 1][index]); + return index; +} + +void WeightedPicker::set_weight(int index, int32 weight) { + assert(index >= 0); + assert(index < N_); + + // Adjust the sums all the way up to the root + const int32 delta = weight - get_weight(index); + for (int l = num_levels_ - 1; l >= 0; l--) { + level_[l][index] += delta; + index >>= 1; + } +} + +void WeightedPicker::SetAllWeights(int32 weight) { + // Initialize leaves + int32* leaves = level_[num_levels_ - 1]; + for (int i = 0; i < N_; i++) leaves[i] = weight; + for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0; + + // Now sum up towards the root + RebuildTreeWeights(); +} + +void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) { + Resize(N); + + // Initialize leaves + int32* leaves = level_[num_levels_ - 1]; + for (int i = 0; i < N_; i++) leaves[i] = weights[i]; + for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0; + + // Now sum up towards the root + RebuildTreeWeights(); +} + +void WeightedPicker::RebuildTreeWeights() { + for (int l = num_levels_ - 2; l >= 0; l--) { + int32* level = level_[l]; + int32* children = level_[l + 1]; + for (int i = 0; i < LevelSize(l); i++) { + level[i] = children[2 * i] + children[2 * i + 1]; + } + } +} + +void WeightedPicker::Append(int32 weight) { + Resize(num_elements() + 1); + set_weight(num_elements() - 1, weight); +} + +void WeightedPicker::Resize(int new_size) { + CHECK_GE(new_size, 0); + if (new_size <= LevelSize(num_levels_ - 1)) { + // The new picker fits in the existing levels. + + // First zero out any of the weights that are being dropped so + // that the levels are correct (only needed when shrinking) + for (int i = new_size; i < N_; i++) { + set_weight(i, 0); + } + + // We do not need to set any new weights when enlarging because + // the unneeded entries always have weight zero. + N_ = new_size; + return; + } + + // We follow the simple strategy of just copying the old + // WeightedPicker into a new WeightedPicker. The cost is + // O(N) regardless. 
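+  // (Sketch of the growth path below: build a fresh picker with enough
+  //  levels for new_size, copy the existing leaf weights into it, zero the
+  //  newly added leaves, rebuild the internal sums, then steal its storage
+  //  via the swaps at the end.)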
+ assert(new_size > N_); + WeightedPicker new_picker(new_size); + int32* dst = new_picker.level_[new_picker.num_levels_ - 1]; + int32* src = this->level_[this->num_levels_ - 1]; + memcpy(dst, src, sizeof(dst[0]) * N_); + memset(dst + N_, 0, sizeof(dst[0]) * (new_size - N_)); + new_picker.RebuildTreeWeights(); + + // Now swap the two pickers + std::swap(new_picker.N_, this->N_); + std::swap(new_picker.num_levels_, this->num_levels_); + std::swap(new_picker.level_, this->level_); + assert(this->N_ == new_size); +} + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/random/weighted_picker.h b/tensorflow/core/lib/random/weighted_picker.h new file mode 100644 index 0000000000..3d2c2dbb39 --- /dev/null +++ b/tensorflow/core/lib/random/weighted_picker.h @@ -0,0 +1,118 @@ + +// An abstraction to pick from one of N elements with a specified +// weight per element. +// +// The weight for a given element can be changed in O(lg N) time +// An element can be picked in O(lg N) time. +// +// Uses O(N) bytes of memory. +// +// Alternative: distribution-sampler.h allows O(1) time picking, but no weight +// adjustment after construction. + +#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ +#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ + +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace random { + +class SimplePhilox; + +class WeightedPicker { + public: + // REQUIRES N >= 0 + // Initializes the elements with a weight of one per element + explicit WeightedPicker(int N); + + // Releases all resources + ~WeightedPicker(); + + // Pick a random element with probability proportional to its weight. + // If total weight is zero, returns -1. + int Pick(SimplePhilox* rnd) const; + + // Deterministically pick element x whose weight covers the + // specified weight_index. + // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ] + int PickAt(int32 weight_index) const; + + // Get the weight associated with an element + // REQUIRES 0 <= index < N + int32 get_weight(int index) const; + + // Set the weight associated with an element + // REQUIRES weight >= 0.0f + // REQUIRES 0 <= index < N + void set_weight(int index, int32 weight); + + // Get the total combined weight of all elements + int32 total_weight() const; + + // Get the number of elements in the picker + int num_elements() const; + + // Set weight of each element to "weight" + void SetAllWeights(int32 weight); + + // Resizes the picker to N and + // sets the weight of each element i to weight[i]. + // The sum of the weights should not exceed 2^31 - 2 + // Complexity O(N). + void SetWeightsFromArray(int N, const int32* weights); + + // REQUIRES N >= 0 + // + // Resize the weighted picker so that it has "N" elements. + // Any newly added entries have zero weight. + // + // Note: Resizing to a smaller size than num_elements() will + // not reclaim any memory. If you wish to reduce memory usage, + // allocate a new WeightedPicker of the appropriate size. + // + // It is efficient to use repeated calls to Resize(num_elements() + 1) + // to grow the picker to size X (takes total time O(X)). + void Resize(int N); + + // Grow the picker by one and set the weight of the new entry to "weight". + // + // Repeated calls to Append() in order to grow the + // picker to size X takes a total time of O(X lg(X)). + // Consider using SetWeightsFromArray instead. 
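+  // For example (an illustrative sketch; the weights are made up), building
+  // a picker over an existing weight array:
+  //
+  //   std::vector<int32> w = {3, 1, 4, 1, 5};
+  //   WeightedPicker picker(0);
+  //   picker.SetWeightsFromArray(w.size(), w.data());  // O(N) total
+  //
+  // is cheaper than growing it one element at a time:
+  //
+  //   for (int32 x : w) picker.Append(x);              // O(N lg N) total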
+ void Append(int32 weight); + + private: + // We keep a binary tree with N leaves. The "i"th leaf contains + // the weight of the "i"th element. An internal node contains + // the sum of the weights of its children. + int N_; // Number of elements + int num_levels_; // Number of levels in tree (level-0 is root) + int32** level_; // Array that holds nodes per level + + // Size of each level + static int LevelSize(int level) { return 1 << level; } + + // Rebuild the tree weights using the leaf weights + void RebuildTreeWeights(); + + TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker); +}; + +inline int32 WeightedPicker::get_weight(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, N_); + return level_[num_levels_ - 1][index]; +} + +inline int32 WeightedPicker::total_weight() const { return level_[0][0]; } + +inline int WeightedPicker::num_elements() const { return N_; } + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ diff --git a/tensorflow/core/lib/random/weighted_picker_test.cc b/tensorflow/core/lib/random/weighted_picker_test.cc new file mode 100644 index 0000000000..0b27d437d5 --- /dev/null +++ b/tensorflow/core/lib/random/weighted_picker_test.cc @@ -0,0 +1,254 @@ +#include "tensorflow/core/lib/random/weighted_picker.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include + +namespace tensorflow { +namespace random { + +static void TestPicker(SimplePhilox* rnd, int size); +static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials); +static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials); +static void TestPickAt(int items, const int32* weights); + +TEST(WeightedPicker, Simple) { + PhiloxRandom philox(testing::RandomSeed(), 17); + SimplePhilox rnd(&philox); + + { + VLOG(0) << "======= Zero-length picker"; + WeightedPicker picker(0); + EXPECT_EQ(picker.Pick(&rnd), -1); + } + + { + VLOG(0) << "======= Singleton picker"; + WeightedPicker picker(1); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + } + + { + VLOG(0) << "======= Grown picker"; + WeightedPicker picker(0); + for (int i = 0; i < 10; i++) { + picker.Append(1); + } + CheckUniform(&rnd, &picker, 100000); + } + + { + VLOG(0) << "======= Grown picker with zero weights"; + WeightedPicker picker(1); + picker.Resize(10); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + } + + { + VLOG(0) << "======= Shrink picker and check weights"; + WeightedPicker picker(1); + picker.Resize(10); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + EXPECT_EQ(picker.Pick(&rnd), 0); + for (int i = 0; i < 10; i++) { + picker.set_weight(i, i); + } + EXPECT_EQ(picker.total_weight(), 45); + picker.Resize(5); + EXPECT_EQ(picker.total_weight(), 10); + picker.Resize(2); + EXPECT_EQ(picker.total_weight(), 1); + picker.Resize(1); + EXPECT_EQ(picker.total_weight(), 0); + } +} + +TEST(WeightedPicker, BigWeights) { + PhiloxRandom philox(testing::RandomSeed() + 1, 17); + SimplePhilox rnd(&philox); + VLOG(0) << "======= Check uniform with big weights"; + WeightedPicker picker(2); + picker.SetAllWeights(2147483646L / 3); // (2^31 - 2) / 3 + CheckUniform(&rnd, &picker, 100000); +} + +TEST(WeightedPicker, 
Deterministic) { + VLOG(0) << "======= Testing deterministic pick"; + static const int32 weights[] = {1, 0, 200, 5, 42}; + TestPickAt(TF_ARRAYSIZE(weights), weights); +} + +TEST(WeightedPicker, Randomized) { + PhiloxRandom philox(testing::RandomSeed() + 10, 17); + SimplePhilox rnd(&philox); + TestPicker(&rnd, 1); + TestPicker(&rnd, 2); + TestPicker(&rnd, 3); + TestPicker(&rnd, 4); + TestPicker(&rnd, 7); + TestPicker(&rnd, 8); + TestPicker(&rnd, 9); + TestPicker(&rnd, 10); + TestPicker(&rnd, 100); +} + +static void TestPicker(SimplePhilox* rnd, int size) { + VLOG(0) << "======= Testing size " << size; + + // Check that empty picker returns -1 + { + WeightedPicker picker(size); + picker.SetAllWeights(0); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1); + } + + // Create zero weights array + std::vector weights(size); + for (int elem = 0; elem < size; elem++) { + weights[elem] = 0; + } + + // Check that singleton picker always returns the same element + for (int elem = 0; elem < size; elem++) { + WeightedPicker picker(size); + picker.SetAllWeights(0); + picker.set_weight(elem, elem + 1); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem); + weights[elem] = 10; + picker.SetWeightsFromArray(size, &weights[0]); + for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem); + weights[elem] = 0; + } + + // Check that uniform picker generates elements roughly uniformly + { + WeightedPicker picker(size); + CheckUniform(rnd, &picker, 100000); + } + + // Check uniform picker that was grown piecemeal + if (size / 3 > 0) { + WeightedPicker picker(size / 3); + while (picker.num_elements() != size) { + picker.Append(1); + } + CheckUniform(rnd, &picker, 100000); + } + + // Check that skewed distribution works + if (size <= 10) { + // When picker grows one element at a time + WeightedPicker picker(size); + int32 weight = 1; + for (int elem = 0; elem < size; elem++) { + picker.set_weight(elem, weight); + weights[elem] = weight; + weight *= 2; + } + CheckSkewed(rnd, &picker, 1000000); + + // When picker is created from an array + WeightedPicker array_picker(0); + array_picker.SetWeightsFromArray(size, &weights[0]); + CheckSkewed(rnd, &array_picker, 1000000); + } +} + +static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, + int trials) { + const int size = picker->num_elements(); + int* count = new int[size]; + memset(count, 0, sizeof(count[0]) * size); + for (int i = 0; i < size * trials; i++) { + const int elem = picker->Pick(rnd); + EXPECT_GE(elem, 0); + EXPECT_LT(elem, size); + count[elem]++; + } + const int expected_min = int(0.9 * trials); + const int expected_max = int(1.1 * trials); + for (int i = 0; i < size; i++) { + EXPECT_GE(count[i], expected_min); + EXPECT_LE(count[i], expected_max); + } + delete[] count; +} + +static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) { + const int size = picker->num_elements(); + int* count = new int[size]; + memset(count, 0, sizeof(count[0]) * size); + for (int i = 0; i < size * trials; i++) { + const int elem = picker->Pick(rnd); + EXPECT_GE(elem, 0); + EXPECT_LT(elem, size); + count[elem]++; + } + + for (int i = 0; i < size - 1; i++) { + LOG(INFO) << i << ": " << count[i]; + const float ratio = float(count[i + 1]) / float(count[i]); + EXPECT_GE(ratio, 1.6f); + EXPECT_LE(ratio, 2.4f); + } + delete[] count; +} + +static void TestPickAt(int items, const int32* weights) { + WeightedPicker picker(items); + picker.SetWeightsFromArray(items, weights); + int weight_index = 0; + for (int i = 0; i < 
items; ++i) { + for (int j = 0; j < weights[i]; ++j) { + int pick = picker.PickAt(weight_index); + EXPECT_EQ(pick, i); + ++weight_index; + } + } + EXPECT_EQ(weight_index, picker.total_weight()); +} + +static void BM_Create(int iters, int arg) { + while (--iters > 0) { + WeightedPicker p(arg); + } +} +BENCHMARK(BM_Create)->Range(1, 1024); + +static void BM_CreateAndSetWeights(int iters, int arg) { + std::vector weights(arg); + for (int i = 0; i < arg; i++) { + weights[i] = i * 10; + } + while (--iters > 0) { + WeightedPicker p(arg); + p.SetWeightsFromArray(arg, &weights[0]); + } +} +BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024); + +static void BM_Pick(int iters, int arg) { + PhiloxRandom philox(301, 17); + SimplePhilox rnd(&philox); + WeightedPicker p(arg); + int result = 0; + while (--iters > 0) { + result += p.Pick(&rnd); + } + VLOG(4) << result; // Dummy use +} +BENCHMARK(BM_Pick)->Range(1, 1024); + +} // namespace random +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc new file mode 100644 index 0000000000..d61129fb3f --- /dev/null +++ b/tensorflow/core/lib/strings/numbers.cc @@ -0,0 +1,260 @@ +#include "tensorflow/core/lib/strings/numbers.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace strings { + +char* FastInt32ToBufferLeft(int32 i, char* buffer) { + uint32 u = i; + if (i < 0) { + *buffer++ = '-'; + // We need to do the negation in modular (i.e., "unsigned") + // arithmetic; MSVC++ apprently warns for plain "-u", so + // we write the equivalent expression "0 - u" instead. + u = 0 - u; + } + return FastUInt32ToBufferLeft(u, buffer); +} + +char* FastUInt32ToBufferLeft(uint32 i, char* buffer) { + char* start = buffer; + do { + *buffer++ = ((i % 10) + '0'); + i /= 10; + } while (i > 0); + *buffer = 0; + std::reverse(start, buffer); + return buffer; +} + +char* FastInt64ToBufferLeft(int64 i, char* buffer) { + uint64 u = i; + if (i < 0) { + *buffer++ = '-'; + u = 0 - u; + } + return FastUInt64ToBufferLeft(u, buffer); +} + +char* FastUInt64ToBufferLeft(uint64 i, char* buffer) { + char* start = buffer; + do { + *buffer++ = ((i % 10) + '0'); + i /= 10; + } while (i > 0); + *buffer = 0; + std::reverse(start, buffer); + return buffer; +} + +static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001; + +char* DoubleToBuffer(double value, char* buffer) { + // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all + // platforms these days. Just in case some system exists where DBL_DIG + // is significantly larger -- and risks overflowing our buffer -- we have + // this assert. + static_assert(DBL_DIG < 20, "DBL_DIG is too big"); + + bool full_precision_needed = true; + if (std::abs(value) <= kDoublePrecisionCheckMax) { + int snprintf_result = + snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG, value); + + // The snprintf should never overflow because the buffer is significantly + // larger than the precision we asked for. + DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + + full_precision_needed = strtod(buffer, NULL) != value; + } + + if (full_precision_needed) { + int snprintf_result = + snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG + 2, value); + + // Should never overflow; see above. 
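+    // (Illustration of why DBL_DIG + 2 digits suffice here: the double
+    //  nearest 1.0/3.0 prints as "0.333333333333333" with 15 significant
+    //  digits, which parses back to a different double, while its 17-digit
+    //  form "0.33333333333333331" round-trips exactly.)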
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + } + return buffer; +} + +bool safe_strto64(const char* str, int64* value) { + if (!str) return false; + + // Skip leading space. + while (isspace(*str)) ++str; + + int64 vlimit = kint64max; + int sign = 1; + if (*str == '-') { + sign = -1; + ++str; + // Different limit for positive and negative integers. + vlimit = kint64min; + } + + if (!isdigit(*str)) return false; + + int64 result = 0; + if (sign == 1) { + do { + int digit = *str - '0'; + if ((vlimit - digit) / 10 < result) { + return false; + } + result = result * 10 + digit; + ++str; + } while (isdigit(*str)); + } else { + do { + int digit = *str - '0'; + if ((vlimit + digit) / 10 > result) { + return false; + } + result = result * 10 - digit; + ++str; + } while (isdigit(*str)); + } + + // Skip trailing space. + while (isspace(*str)) ++str; + + if (*str) return false; + + *value = result; + return true; +} + +bool safe_strto32(const char* str, int32* value) { + if (!str) return false; + + // Skip leading space. + while (isspace(*str)) ++str; + + int64 vmax = kint32max; + int sign = 1; + if (*str == '-') { + sign = -1; + ++str; + // Different max for positive and negative integers. + ++vmax; + } + + if (!isdigit(*str)) return false; + + int64 result = 0; + do { + result = result * 10 + *str - '0'; + if (result > vmax) { + return false; + } + ++str; + } while (isdigit(*str)); + + // Skip trailing space. + while (isspace(*str)) ++str; + + if (*str) return false; + + *value = result * sign; + return true; +} + +bool safe_strtof(const char* str, float* value) { + char* endptr; + *value = strtof(str, &endptr); + while (isspace(*endptr)) ++endptr; + // Ignore range errors from strtod/strtof. + // The values it returns on underflow and + // overflow are the right fallback in a + // robust setting. + return *str != '\0' && *endptr == '\0'; +} + +char* FloatToBuffer(float value, char* buffer) { + // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all + // platforms these days. Just in case some system exists where FLT_DIG + // is significantly larger -- and risks overflowing our buffer -- we have + // this assert. + static_assert(FLT_DIG < 10, "FLT_DIG is too big"); + + int snprintf_result = + snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG, value); + + // The snprintf should never overflow because the buffer is significantly + // larger than the precision we asked for. + DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + + float parsed_value; + if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) { + snprintf_result = + snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 2, value); + + // Should never overflow; see above. + DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + } + return buffer; +} + +string FpToString(Fprint fp) { + char buf[17]; + snprintf(buf, sizeof(buf), "%016llx", static_cast(fp)); + return string(buf); +} + +bool StringToFp(const string& s, Fprint* fp) { + char junk; + uint64 result; + if (sscanf(s.c_str(), "%llx%c", &result, &junk) == 1) { + *fp = result; + return true; + } else { + return false; + } +} + +string HumanReadableNumBytes(int64 num_bytes) { + if (num_bytes == kint64min) { + // Special case for number with not representable negation. + return "-8E"; + } + + const char* neg_str = (num_bytes < 0) ? "-" : ""; + if (num_bytes < 0) { + num_bytes = -num_bytes; + } + + // Special case for bytes. + if (num_bytes < 1024) { + // No fractions for bytes. 
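+  // (Worked example for the general path further down: 12345678 bytes is
+  //  >= 2^20, so the loop divides by 1024 once, leaving 12056 with unit
+  //  'M'; 12056 / 1024.0 = 11.77, hence "11.77MiB" -- the value the unit
+  //  test expects for this input.)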
+ char buf[8]; // Longest possible string is '-XXXXB' + snprintf(buf, sizeof(buf), "%s%lldB", neg_str, + static_cast(num_bytes)); + return string(buf); + } + + static const char units[] = "KMGTPE"; // int64 only goes up to E. + const char* unit = units; + while (num_bytes >= static_cast(1024) * 1024) { + num_bytes /= 1024; + ++unit; + CHECK(unit < units + TF_ARRAYSIZE(units)); + } + + // We use SI prefixes. + char buf[16]; + snprintf(buf, sizeof(buf), ((*unit == 'K') ? "%s%.1f%ciB" : "%s%.2f%ciB"), + neg_str, num_bytes / 1024.0, *unit); + return string(buf); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h new file mode 100644 index 0000000000..a30a862279 --- /dev/null +++ b/tensorflow/core/lib/strings/numbers.h @@ -0,0 +1,92 @@ +#ifndef TENSORFLOW_LIB_STRINGS_NUMBERS_H_ +#define TENSORFLOW_LIB_STRINGS_NUMBERS_H_ + +#include + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace strings { + +// ---------------------------------------------------------------------- +// FastIntToBufferLeft() +// These are intended for speed. +// +// All functions take the output buffer as an arg. FastInt() uses +// at most 22 bytes, FastTime() uses exactly 30 bytes. They all +// return a pointer to the beginning of the output, which is the same as +// the beginning of the input buffer. +// +// NOTE: In 64-bit land, sizeof(time_t) is 8, so it is possible +// to pass to FastTimeToBuffer() a time whose year cannot be +// represented in 4 digits. In this case, the output buffer +// will contain the string "Invalid:" +// ---------------------------------------------------------------------- + +// Previously documented minimums -- the buffers provided must be at least this +// long, though these numbers are subject to change: +// Int32, UInt32: 12 bytes +// Int64, UInt64, Int, Uint: 22 bytes +// Time: 30 bytes +// Use kFastToBufferSize rather than hardcoding constants. +static const int kFastToBufferSize = 32; + +// ---------------------------------------------------------------------- +// FastInt32ToBufferLeft() +// FastUInt32ToBufferLeft() +// FastInt64ToBufferLeft() +// FastUInt64ToBufferLeft() +// +// These functions convert their numeric argument to an ASCII +// representation of the numeric value in base 10, with the +// representation being left-aligned in the buffer. The caller is +// responsible for ensuring that the buffer has enough space to hold +// the output. The buffer should typically be at least kFastToBufferSize +// bytes. +// +// Returns a pointer to the end of the string (i.e. the null character +// terminating the string). +// ---------------------------------------------------------------------- + +char* FastInt32ToBufferLeft(int32 i, char* buffer); // at least 12 bytes +char* FastUInt32ToBufferLeft(uint32 i, char* buffer); // at least 12 bytes +char* FastInt64ToBufferLeft(int64 i, char* buffer); // at least 22 bytes +char* FastUInt64ToBufferLeft(uint64 i, char* buffer); // at least 22 bytes + +// Required buffer size for DoubleToBuffer is kFastToBufferSize. +// Required buffer size for FloatToBuffer is kFastToBufferSize. +char* DoubleToBuffer(double i, char* buffer); +char* FloatToBuffer(float i, char* buffer); + +// Convert a 64-bit fingerprint value to an ASCII representation. +string FpToString(Fprint fp); + +// Attempt to parse a fingerprint in the form encoded by FpToString. If +// successsful, stores the fingerprint in *fp and returns true. 
Otherwise, +// returns false. +bool StringToFp(const string& s, Fprint* fp); + +// Convert strings to 32bit integer values. +// Leading and trailing spaces are allowed. +// Return false with overflow or invalid input. +bool safe_strto32(const char* str, int32* value); + +// Convert strings to 64bit integer values. +// Leading and trailing spaces are allowed. +// Return false with overflow or invalid input. +bool safe_strto64(const char* str, int64* value); + +// Convert strings to floating point values. +// Leading and trailing spaces are allowed. +// Values may be rounded on over- and underflow. +bool safe_strtof(const char* str, float* value); + +// Converts from an int64 representing a number of bytes to a +// human readable string representing the same number. +// e.g. 12345678 -> "11.77MiB". +string HumanReadableNumBytes(int64 num_bytes); + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_NUMBERS_H_ diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc new file mode 100644 index 0000000000..b178e6af53 --- /dev/null +++ b/tensorflow/core/lib/strings/numbers_test.cc @@ -0,0 +1,113 @@ +#include "tensorflow/core/lib/strings/numbers.h" + +#include +#include + +namespace tensorflow { +namespace strings { + +// NOTE: most of the routines in numbers.h are tested indirectly through +// strcat_test.cc in this directory. + +// Test StrCat of ints and longs of various sizes and signdedness. +TEST(FpToString, Ints) { + for (int s = 0; s < 64; s++) { + for (int delta = -1; delta <= 1; delta++) { + uint64 fp = (1ull << s) + delta; + string s = FpToString(fp); + uint64 fp2; + EXPECT_TRUE(StringToFp(s, &fp2)); + EXPECT_EQ(fp, fp2); + } + } + Fprint dummy; + EXPECT_FALSE(StringToFp("", &dummy)); + EXPECT_FALSE(StringToFp("xyz", &dummy)); + EXPECT_FALSE(StringToFp("0000000000000000xyz", &dummy)); +} + +TEST(HumanReadableNumBytes, Bytes) { + EXPECT_EQ("0B", HumanReadableNumBytes(0)); + EXPECT_EQ("4B", HumanReadableNumBytes(4)); + EXPECT_EQ("1023B", HumanReadableNumBytes(1023)); + + EXPECT_EQ("1.0KiB", HumanReadableNumBytes(1024)); + EXPECT_EQ("1.0KiB", HumanReadableNumBytes(1025)); + EXPECT_EQ("1.5KiB", HumanReadableNumBytes(1500)); + EXPECT_EQ("1.9KiB", HumanReadableNumBytes(1927)); + + EXPECT_EQ("2.0KiB", HumanReadableNumBytes(2048)); + EXPECT_EQ("1.00MiB", HumanReadableNumBytes(1 << 20)); + EXPECT_EQ("11.77MiB", HumanReadableNumBytes(12345678)); + EXPECT_EQ("1.00GiB", HumanReadableNumBytes(1 << 30)); + + EXPECT_EQ("1.00TiB", HumanReadableNumBytes(1LL << 40)); + EXPECT_EQ("1.00PiB", HumanReadableNumBytes(1LL << 50)); + EXPECT_EQ("1.00EiB", HumanReadableNumBytes(1LL << 60)); + + // Try a few negative numbers + EXPECT_EQ("-1B", HumanReadableNumBytes(-1)); + EXPECT_EQ("-4B", HumanReadableNumBytes(-4)); + EXPECT_EQ("-1000B", HumanReadableNumBytes(-1000)); + EXPECT_EQ("-11.77MiB", HumanReadableNumBytes(-12345678)); + EXPECT_EQ("-8E", HumanReadableNumBytes(kint64min)); +} + +TEST(safe_strto32, Int32s) { + int32 result; + + EXPECT_EQ(true, safe_strto32("1", &result)); + EXPECT_EQ(1, result); + EXPECT_EQ(true, safe_strto32("123", &result)); + EXPECT_EQ(123, result); + EXPECT_EQ(true, safe_strto32(" -123 ", &result)); + EXPECT_EQ(-123, result); + EXPECT_EQ(true, safe_strto32("2147483647", &result)); + EXPECT_EQ(2147483647, result); + EXPECT_EQ(true, safe_strto32("-2147483648", &result)); + EXPECT_EQ(-2147483648, result); + + // Invalid argument + EXPECT_EQ(false, safe_strto32(" 132as ", &result)); + EXPECT_EQ(false, 
safe_strto32(" 132.2 ", &result)); + EXPECT_EQ(false, safe_strto32(" -", &result)); + EXPECT_EQ(false, safe_strto32("", &result)); + EXPECT_EQ(false, safe_strto32(" ", &result)); + EXPECT_EQ(false, safe_strto32("123 a", &result)); + + // Overflow + EXPECT_EQ(false, safe_strto32("2147483648", &result)); + EXPECT_EQ(false, safe_strto32("-2147483649", &result)); +} + +TEST(safe_strto64, Int64s) { + int64 result; + + EXPECT_EQ(true, safe_strto64("1", &result)); + EXPECT_EQ(1, result); + EXPECT_EQ(true, safe_strto64("123", &result)); + EXPECT_EQ(123, result); + EXPECT_EQ(true, safe_strto64(" -123 ", &result)); + EXPECT_EQ(-123, result); + EXPECT_EQ(true, safe_strto64("9223372036854775807", &result)); + EXPECT_EQ(9223372036854775807, result); + EXPECT_EQ(true, safe_strto64("-9223372036854775808", &result)); + // kint64min == -9223372036854775808 + // Use -9223372036854775808 directly results in out of range error + EXPECT_EQ(kint64min, result); + + // Invalid argument + EXPECT_EQ(false, safe_strto64(" 132as ", &result)); + EXPECT_EQ(false, safe_strto64(" 132.2 ", &result)); + EXPECT_EQ(false, safe_strto64(" -", &result)); + EXPECT_EQ(false, safe_strto64("", &result)); + EXPECT_EQ(false, safe_strto64(" ", &result)); + EXPECT_EQ(false, safe_strto64("123 a", &result)); + + // Overflow + EXPECT_EQ(false, safe_strto64("9223372036854775808", &result)); + EXPECT_EQ(false, safe_strto64("-9223372036854775809", &result)); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/ordered_code.cc b/tensorflow/core/lib/strings/ordered_code.cc new file mode 100644 index 0000000000..ec67595ebb --- /dev/null +++ b/tensorflow/core/lib/strings/ordered_code.cc @@ -0,0 +1,515 @@ +#include "tensorflow/core/lib/strings/ordered_code.h" + +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace strings { + +// We encode a string in different ways depending on whether the item +// should be in lexicographically increasing or decreasing order. +// +// +// Lexicographically increasing order +// +// We want a string-to-string mapping F(x) such that for any two strings +// +// x < y => F(x) < F(y) +// +// In addition to the normal characters '\x00' through '\xff', we want to +// encode a few extra symbols in strings: +// +// Separator between items +// Infinite string +// +// Therefore we need an alphabet with at least 258 symbols. Each +// character '\1' through '\xfe' is mapped to itself. The other four are +// encoded into two-letter sequences starting with '\0' and '\xff': +// +// encoded as => \0\1 +// \0 encoded as => \0\xff +// \xff encoded as => \xff\x00 +// encoded as => \xff\xff +// +// The remaining two-letter sequences starting with '\0' and '\xff' are +// currently unused. +// +// F() is defined above. For any finite string x, F(x) is the +// the encodings of x's characters followed by the encoding for . The +// ordering of two finite strings is the same as the ordering of the +// respective characters at the first position where they differ, which in +// turn is the same as the ordering of the encodings of those two +// characters. Moreover, for every finite string x, F(x) < F(). +// +// +// Lexicographically decreasing order +// +// We want a string-to-string mapping G(x) such that for any two strings, +// whether finite or not, +// +// x < y => G(x) > G(y) +// +// To achieve this, define G(x) to be the inversion of F(x): I(F(x)). 
In +// other words, invert every bit in F(x) to get G(x). For example, +// +// x = \x00\x13\xff +// F(x) = \x00\xff\x13\xff\x00\x00\x01 escape \0, \xff, append F() +// G(x) = \xff\x00\xec\x00\xff\xff\xfe invert every bit in F(x) +// +// x = +// F(x) = \xff\xff +// G(x) = \x00\x00 +// +// Another example is +// +// x F(x) G(x) = I(F(x)) +// - ---- -------------- +// \xff\xff \x00\x00 +// "foo" foo\0\1 \x99\x90\x90\xff\xfe +// "aaa" aaa\0\1 \x9e\x9e\x9e\xff\xfe +// "aa" aa\0\1 \x9e\x9e\xff\xfe +// "" \0\1 \xff\xfe +// +// More generally and rigorously, if for any two strings x and y +// +// F(x) < F(y) => I(F(x)) > I(F(y)) (1) +// +// it would follow that x < y => G(x) > G(y) because +// +// x < y => F(x) < F(y) => G(x) = I(F(x)) > I(F(y)) = G(y) +// +// We now show why (1) is true, in two parts. Notice that for any two +// strings x < y, F(x) is *not* a proper prefix of F(y). Suppose x is a +// proper prefix of y (say, x="abc" < y="abcd"). F(x) and F(y) diverge at +// the F() in F(x) (v. F('d') in the example). Suppose x is not a +// proper prefix of y (say, x="abce" < y="abd"), F(x) and F(y) diverge at +// their respective encodings of the characters where x and y diverge +// (F('c') v. F('d')). Finally, if y=, we can see that +// F(y)=\xff\xff is not the prefix of F(x) for any finite string x, simply +// by considering all the possible first characters of F(x). +// +// Given that F(x) is not a proper prefix F(y), the order of F(x) and F(y) +// is determined by the byte where F(x) and F(y) diverge. For example, the +// order of F(x)="eefh" and F(y)="eeg" is determined by their third +// characters. I(p) inverts each byte in p, which effectively subtracts +// each byte from 0xff. So, in this example, I('f') > I('g'), and thus +// I(F(x)) > I(F(y)). +// +// +// Implementation +// +// To implement G(x) efficiently, we use C++ template to instantiate two +// versions of the code to produce F(x), one for normal encoding (giving us +// F(x)) and one for inverted encoding (giving us G(x) = I(F(x))). + +static const char kEscape1 = '\000'; +static const char kNullCharacter = '\xff'; // Combined with kEscape1 +static const char kSeparator = '\001'; // Combined with kEscape1 + +static const char kEscape2 = '\xff'; +static const char kInfinity = '\xff'; // Combined with kEscape2 +static const char kFFCharacter = '\000'; // Combined with kEscape2 + +static const char kEscape1_Separator[2] = {kEscape1, kSeparator}; + +// Append to "*dest" the "len" bytes starting from "*src". +inline static void AppendBytes(string* dest, const char* src, int len) { + dest->append(src, len); +} + +inline bool IsSpecialByte(char c) { return ((unsigned char)(c + 1)) < 2; } + +// Return a pointer to the first byte in the range "[start..limit)" +// whose value is 0 or 255 (kEscape1 or kEscape2). If no such byte +// exists in the range, returns "limit". +inline const char* SkipToNextSpecialByte(const char* start, const char* limit) { + // If these constants were ever changed, this routine needs to change + DCHECK_EQ(kEscape1, 0); + DCHECK_EQ(kEscape2 & 0xffu, 255u); + const char* p = start; + while (p < limit && !IsSpecialByte(*p)) { + p++; + } + return p; +} + +// Expose SkipToNextSpecialByte for testing purposes +const char* OrderedCode::TEST_SkipToNextSpecialByte(const char* start, + const char* limit) { + return SkipToNextSpecialByte(start, limit); +} + +// Helper routine to encode "s" and append to "*dest", escaping special +// characters. 
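+// (A concrete instance of the escaping, for illustration: the three-byte
+//  string {'a', '\0', '\xff'} is escaped to
+//      'a'  '\0' '\xff'  '\xff' '\0'
+//  and WriteString() then appends the terminator '\0' '\1', so two encoded
+//  strings compare exactly like the strings they encode.)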
+inline static void EncodeStringFragment(string* dest, StringPiece s) { + const char* p = s.data(); + const char* limit = p + s.size(); + const char* copy_start = p; + while (true) { + p = SkipToNextSpecialByte(p, limit); + if (p >= limit) break; // No more special characters that need escaping + char c = *(p++); + DCHECK(IsSpecialByte(c)); + if (c == kEscape1) { + AppendBytes(dest, copy_start, p - copy_start - 1); + dest->push_back(kEscape1); + dest->push_back(kNullCharacter); + copy_start = p; + } else { + assert(c == kEscape2); + AppendBytes(dest, copy_start, p - copy_start - 1); + dest->push_back(kEscape2); + dest->push_back(kFFCharacter); + copy_start = p; + } + } + if (p > copy_start) { + AppendBytes(dest, copy_start, p - copy_start); + } +} + +void OrderedCode::WriteString(string* dest, StringPiece s) { + EncodeStringFragment(dest, s); + AppendBytes(dest, kEscape1_Separator, 2); +} + +void OrderedCode::WriteNumIncreasing(string* dest, uint64 val) { + // Values are encoded with a single byte length prefix, followed + // by the actual value in big-endian format with leading 0 bytes + // dropped. + unsigned char buf[9]; // 8 bytes for value plus one byte for length + int len = 0; + while (val > 0) { + len++; + buf[9 - len] = (val & 0xff); + val >>= 8; + } + buf[9 - len - 1] = (unsigned char)len; + len++; + AppendBytes(dest, reinterpret_cast(buf + 9 - len), len); +} + +// Parse the encoding of a previously encoded string. +// If parse succeeds, return true, consume encoding from +// "*src", and if result != NULL append the decoded string to "*result". +// Otherwise, return false and leave both undefined. +inline static bool ReadStringInternal(StringPiece* src, string* result) { + const char* start = src->data(); + const char* string_limit = src->data() + src->size(); + + // We only scan up to "limit-2" since a valid string must end with + // a two character terminator: 'kEscape1 kSeparator' + const char* limit = string_limit - 1; + const char* copy_start = start; + while (true) { + start = SkipToNextSpecialByte(start, limit); + if (start >= limit) break; // No terminator sequence found + const char c = *(start++); + // If inversion is required, instead of inverting 'c', we invert the + // character constants to which 'c' is compared. We get the same + // behavior but save the runtime cost of inverting 'c'. 
+ DCHECK(IsSpecialByte(c)); + if (c == kEscape1) { + if (result) { + AppendBytes(result, copy_start, start - copy_start - 1); + } + // kEscape1 kSeparator ends component + // kEscape1 kNullCharacter represents '\0' + const char next = *(start++); + if (next == kSeparator) { + src->remove_prefix(start - src->data()); + return true; + } else if (next == kNullCharacter) { + if (result) { + *result += '\0'; + } + } else { + return false; + } + copy_start = start; + } else { + assert(c == kEscape2); + if (result) { + AppendBytes(result, copy_start, start - copy_start - 1); + } + // kEscape2 kFFCharacter represents '\xff' + // kEscape2 kInfinity is an error + const char next = *(start++); + if (next == kFFCharacter) { + if (result) { + *result += '\xff'; + } + } else { + return false; + } + copy_start = start; + } + } + return false; +} + +bool OrderedCode::ReadString(StringPiece* src, string* result) { + return ReadStringInternal(src, result); +} + +bool OrderedCode::ReadNumIncreasing(StringPiece* src, uint64* result) { + if (src->empty()) { + return false; // Not enough bytes + } + + // Decode length byte + const size_t len = static_cast((*src)[0]); + + // If len > 0 and src is longer than 1, the first byte of "payload" + // must be non-zero (otherwise the encoding is not minimal). + // In opt mode, we don't enforce that encodings must be minimal. + DCHECK(0 == len || src->size() == 1 || (*src)[1] != '\0') + << "invalid encoding"; + + if (len + 1 > src->size() || len > 8) { + return false; // Not enough bytes or too many bytes + } + + if (result) { + uint64 tmp = 0; + for (size_t i = 0; i < len; i++) { + tmp <<= 8; + tmp |= static_cast((*src)[1 + i]); + } + *result = tmp; + } + src->remove_prefix(len + 1); + return true; +} + +void OrderedCode::TEST_Corrupt(string* str, int k) { + int seen_seps = 0; + for (size_t i = 0; i + 1 < str->size(); i++) { + if ((*str)[i] == kEscape1 && (*str)[i + 1] == kSeparator) { + seen_seps++; + if (seen_seps == k) { + (*str)[i + 1] = kSeparator + 1; + return; + } + } + } +} + +// Signed number encoding/decoding ///////////////////////////////////// +// +// The format is as follows: +// +// The first bit (the most significant bit of the first byte) +// represents the sign, 0 if the number is negative and +// 1 if the number is >= 0. +// +// Any unbroken sequence of successive bits with the same value as the sign +// bit, up to 9 (the 8th and 9th are the most significant bits of the next +// byte), are size bits that count the number of bytes after the first byte. +// That is, the total length is between 1 and 10 bytes. +// +// The value occupies the bits after the sign bit and the "size bits" +// till the end of the string, in network byte order. If the number +// is negative, the bits are in 2-complement. 
+// +// +// Example 1: number 0x424242 -> 4 byte big-endian hex string 0xf0424242: +// +// +---------------+---------------+---------------+---------------+ +// 1 1 1 1 0 0 0 0 0 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 1 0 +// +---------------+---------------+---------------+---------------+ +// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// | | | | payload: the remaining bits after the sign and size bits +// | | | | and the delimiter bit, the value is 0x424242 +// | | | | +// | size bits: 3 successive bits with the same value as the sign bit +// | (followed by a delimiter bit with the opposite value) +// | mean that there are 3 bytes after the first byte, 4 total +// | +// sign bit: 1 means that the number is non-negative +// +// Example 2: negative number -0x800 -> 2 byte big-endian hex string 0x3800: +// +// +---------------+---------------+ +// 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 +// +---------------+---------------+ +// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// | | payload: the remaining bits after the sign and size bits and the +// | | delimiter bit, 2-complement because of the negative sign, +// | | value is ~0x7ff, represents the value -0x800 +// | | +// | size bits: 1 bit with the same value as the sign bit +// | (followed by a delimiter bit with the opposite value) +// | means that there is 1 byte after the first byte, 2 total +// | +// sign bit: 0 means that the number is negative +// +// +// Compared with the simpler unsigned format used for uint64 numbers, +// this format is more compact for small numbers, namely one byte encodes +// numbers in the range [-64,64), two bytes cover the range [-2^13,2^13), etc. +// In general, n bytes encode numbers in the range [-2^(n*7-1),2^(n*7-1)). +// (The cross-over point for compactness of representation is 8 bytes, +// where this format only covers the range [-2^55,2^55), +// whereas an encoding with sign bit and length in the first byte and +// payload in all following bytes would cover [-2^56,2^56).) + +static const int kMaxSigned64Length = 10; + +// This array maps encoding length to header bits in the first two bytes. +static const char kLengthToHeaderBits[1 + kMaxSigned64Length][2] = { + {0, 0}, {'\x80', 0}, {'\xc0', 0}, {'\xe0', 0}, + {'\xf0', 0}, {'\xf8', 0}, {'\xfc', 0}, {'\xfe', 0}, + {'\xff', 0}, {'\xff', '\x80'}, {'\xff', '\xc0'}}; + +// This array maps encoding lengths to the header bits that overlap with +// the payload and need fixing when reading. +static const uint64 kLengthToMask[1 + kMaxSigned64Length] = { + 0ULL, + 0x80ULL, + 0xc000ULL, + 0xe00000ULL, + 0xf0000000ULL, + 0xf800000000ULL, + 0xfc0000000000ULL, + 0xfe000000000000ULL, + 0xff00000000000000ULL, + 0x8000000000000000ULL, + 0ULL}; + +// This array maps the number of bits in a number to the encoding +// length produced by WriteSignedNumIncreasing. +// For positive numbers, the number of bits is 1 plus the most significant +// bit position (the highest bit position in a positive int64 is 63). +// For a negative number n, we count the bits in ~n. +// That is, length = kBitsToLength[Bits::Log2Floor64(n < 0 ? ~n : n) + 1]. 
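+// (For instance: n = 100 has Log2Floor64(100) == 6, so kBitsToLength[7] == 2
+//  and the encoding takes two bytes, matching the two-byte range
+//  [-2^13, 2^13) above; n = 63 has Log2Floor64(63) == 5, kBitsToLength[6] == 1,
+//  i.e. one byte, matching [-64, 64); a negative n such as -100 uses
+//  ~n == 99 and again encodes in two bytes.)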
+static const int8 kBitsToLength[1 + 63] = { + 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, + 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 7, + 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 10}; + +#if defined(__GNUC__) +// Returns floor(lg(n)). Returns -1 if n == 0. +static int Log2Floor64(uint64 n) { + return n == 0 ? -1 : 63 ^ __builtin_clzll(n); +} +#else +// Portable slow version +static int Log2Floor32_Portable(uint32 n) { + if (n == 0) return -1; + int log = 0; + uint32 value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32 x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + assert(value == 1); + return log; +} +// Returns floor(lg(n)). Returns -1 if n == 0. +static int Log2Floor64(uint64 n) { + const uint32 topbits = static_cast(n >> 32); + if (topbits == 0) { + // Top bits are zero, so scan in bottom bits + return Log2Floor32_Portable(static_cast(n)); + } else { + return 32 + Log2Floor32_Portable(topbits); + } +} +#endif + +// Calculates the encoding length in bytes of the signed number n. +static inline int SignedEncodingLength(int64 n) { + return kBitsToLength[Log2Floor64(n < 0 ? ~n : n) + 1]; +} + +static void StoreBigEndian64(char* dst, uint64 v) { + for (int i = 0; i < 8; i++) { + dst[i] = (v >> (56 - 8 * i)) & 0xff; + } +} + +static uint64 LoadBigEndian64(const char* src) { + uint64 result = 0; + for (int i = 0; i < 8; i++) { + unsigned char c = static_cast(src[i]); + result |= static_cast(c) << (56 - 8 * i); + } + return result; +} + +void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) { + const uint64 x = val < 0 ? ~val : val; + if (x < 64) { // fast path for encoding length == 1 + *dest += kLengthToHeaderBits[1][0] ^ val; + return; + } + // buf = val in network byte order, sign extended to 10 bytes + const char sign_byte = val < 0 ? '\xff' : '\0'; + char buf[10] = { + sign_byte, sign_byte, + }; + StoreBigEndian64(buf + 2, val); + static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch"); + const int len = SignedEncodingLength(x); + DCHECK_GE(len, 2); + char* const begin = buf + sizeof(buf) - len; + begin[0] ^= kLengthToHeaderBits[len][0]; + begin[1] ^= kLengthToHeaderBits[len][1]; // ok because len >= 2 + dest->append(begin, len); +} + +bool OrderedCode::ReadSignedNumIncreasing(StringPiece* src, int64* result) { + if (src->empty()) return false; + const uint64 xor_mask = (!((*src)[0] & 0x80)) ? 
~0ULL : 0ULL; + const unsigned char first_byte = (*src)[0] ^ (xor_mask & 0xff); + + // now calculate and test length, and set x to raw (unmasked) result + int len; + uint64 x; + if (first_byte != 0xff) { + len = 7 - Log2Floor64(first_byte ^ 0xff); + if (src->size() < static_cast(len)) return false; + x = xor_mask; // sign extend using xor_mask + for (int i = 0; i < len; ++i) + x = (x << 8) | static_cast((*src)[i]); + } else { + len = 8; + if (src->size() < static_cast(len)) return false; + const unsigned char second_byte = (*src)[1] ^ (xor_mask & 0xff); + if (second_byte >= 0x80) { + if (second_byte < 0xc0) { + len = 9; + } else { + const unsigned char third_byte = (*src)[2] ^ (xor_mask & 0xff); + if (second_byte == 0xc0 && third_byte < 0x80) { + len = 10; + } else { + return false; // either len > 10 or len == 10 and #bits > 63 + } + } + if (src->size() < static_cast(len)) return false; + } + x = LoadBigEndian64(src->data() + len - 8); + } + + x ^= kLengthToMask[len]; // remove spurious header bits + + DCHECK_EQ(len, SignedEncodingLength(x)) << "invalid encoding"; + + if (result) *result = x; + src->remove_prefix(len); + return true; +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/ordered_code.h b/tensorflow/core/lib/strings/ordered_code.h new file mode 100644 index 0000000000..39f1df9a94 --- /dev/null +++ b/tensorflow/core/lib/strings/ordered_code.h @@ -0,0 +1,77 @@ +// This module provides routines for encoding a sequence of typed +// entities into a string. The resulting strings can be +// lexicographically compared to yield the same comparison value that +// would have been generated if the encoded items had been compared +// one by one according to their type. +// +// More precisely, suppose: +// 1. string A is generated by encoding the sequence of items [A_1..A_n] +// 2. string B is generated by encoding the sequence of items [B_1..B_n] +// 3. The types match; i.e., for all i: A_i was encoded using +// the same routine as B_i +// Then: +// Comparing A vs. B lexicographically is the same as comparing +// the vectors [A_1..A_n] and [B_1..B_n] lexicographically. +// +// Furthermore, if n < m, the encoding of [A_1..A_n] is a strict prefix of +// [A_1..A_m] (unless m = n+1 and A_m is the empty string encoded with +// WriteTrailingString, in which case the encodings are equal). +// +// This module is often useful when generating multi-part sstable +// keys that have to be ordered in a particular fashion. + +#ifndef TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__ +#define TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__ + +#include +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +class StringPiece; + +namespace strings { + +class OrderedCode { + public: + // ------------------------------------------------------------------- + // Encoding routines: each one of the following routines append + // one item to "*dest" in an encoding where larger values are + // ordered lexicographically after smaller values. + static void WriteString(string* dest, StringPiece str); + static void WriteNumIncreasing(string* dest, uint64 num); + static void WriteSignedNumIncreasing(string* dest, int64 num); + + // ------------------------------------------------------------------- + // Decoding routines: these extract an item earlier encoded using + // the corresponding WriteXXX() routines above. 
The item is read + // from "*src"; "*src" is modified to point past the decoded item; + // and if "result" is non-NULL, "*result" is modified to contain the + // result. In case of string result, the decoded string is appended to + // "*result". Returns true if the next item was read successfully, false + // otherwise. + static bool ReadString(StringPiece* src, string* result); + static bool ReadNumIncreasing(StringPiece* src, uint64* result); + static bool ReadSignedNumIncreasing(StringPiece* src, int64* result); + + // Helper for testing: corrupt "*str" by changing the kth item separator + // in the string. + static void TEST_Corrupt(string* str, int k); + + // Helper for testing. + // SkipToNextSpecialByte is an internal routine defined in the .cc file + // with the following semantics. Return a pointer to the first byte + // in the range "[start..limit)" whose value is 0 or 255. If no such + // byte exists in the range, returns "limit". + static const char* TEST_SkipToNextSpecialByte(const char* start, + const char* limit); + + private: + // This has only static methods, so disallow construction entirely + OrderedCode(); + TF_DISALLOW_COPY_AND_ASSIGN(OrderedCode); +}; + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__ diff --git a/tensorflow/core/lib/strings/ordered_code_test.cc b/tensorflow/core/lib/strings/ordered_code_test.cc new file mode 100644 index 0000000000..d517d14f4a --- /dev/null +++ b/tensorflow/core/lib/strings/ordered_code_test.cc @@ -0,0 +1,1183 @@ +#include "tensorflow/core/lib/strings/ordered_code.h" + +#include +#include +#include +#include + +#include +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace strings { + +static string RandomString(random::SimplePhilox* rnd, int len) { + string x; + for (int i = 0; i < len; i++) { + x += rnd->Uniform(256); + } + return x; +} + +// --------------------------------------------------------------------- +// Utility template functions (they help templatize the tests below) + +// Read/WriteIncreasing are defined for string, uint64, int64 below. 
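To make the ordering contract documented in ordered_code.h concrete before the typed test wrappers below, here is a minimal illustrative sketch; the namespace `example` and the helpers `MakeKey` and `OrderingHolds` are hypothetical names introduced only for this illustration, not part of the library or the tests. Composing a (string, uint64) key with WriteString and WriteNumIncreasing yields encodings whose byte-wise comparison matches item-by-item comparison of the tuples.

// Illustrative sketch only: assumes ordered_code.h as declared above.
#include <string>

#include "tensorflow/core/lib/strings/ordered_code.h"

namespace example {

// Encodes a two-part key (name, version) so that encoded keys sort
// first by name and then by version.
std::string MakeKey(const std::string& name, tensorflow::uint64 version) {
  std::string key;
  tensorflow::strings::OrderedCode::WriteString(&key, name);
  tensorflow::strings::OrderedCode::WriteNumIncreasing(&key, version);
  return key;
}

// ("row", 9) < ("row", 10) < ("row2", 0) as tuples, and the encoded
// strings compare the same way byte for byte.
bool OrderingHolds() {
  return MakeKey("row", 9) < MakeKey("row", 10) &&
         MakeKey("row", 10) < MakeKey("row2", 0);
}

}  // namespace example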
+template <typename T>
+static void OCWriteIncreasing(string* dest, const T& val);
+template <typename T>
+static bool OCReadIncreasing(StringPiece* src, T* result);
+
+// Read/WriteIncreasing<string>
+template <>
+void OCWriteIncreasing<string>(string* dest, const string& val) {
+  OrderedCode::WriteString(dest, val);
+}
+template <>
+bool OCReadIncreasing<string>(StringPiece* src, string* result) {
+  return OrderedCode::ReadString(src, result);
+}
+
+// Read/WriteIncreasing<uint64>
+template <>
+void OCWriteIncreasing<uint64>(string* dest, const uint64& val) {
+  OrderedCode::WriteNumIncreasing(dest, val);
+}
+template <>
+bool OCReadIncreasing<uint64>(StringPiece* src, uint64* result) {
+  return OrderedCode::ReadNumIncreasing(src, result);
+}
+
+// Read/WriteIncreasing<int64>
+template <>
+void OCWriteIncreasing<int64>(string* dest, const int64& val) {
+  OrderedCode::WriteSignedNumIncreasing(dest, val);
+}
+template <>
+bool OCReadIncreasing<int64>(StringPiece* src, int64* result) {
+  return OrderedCode::ReadSignedNumIncreasing(src, result);
+}
+
+template <typename T>
+string OCWrite(T val) {
+  string result;
+  OCWriteIncreasing(&result, val);
+  return result;
+}
+
+template <typename T>
+void OCWriteToString(string* result, T val) {
+  OCWriteIncreasing(result, val);
+}
+
+template <typename T>
+bool OCRead(StringPiece* s, T* val) {
+  return OCReadIncreasing(s, val);
+}
+
+// ---------------------------------------------------------------------
+// Numbers
+
+template <typename T>
+static T TestRead(const string& a) {
+  // gracefully reject any proper prefix of an encoding
+  for (int i = 0; i < a.size() - 1; ++i) {
+    StringPiece s(a.data(), i);
+    CHECK(!OCRead<T>(&s, NULL));
+    CHECK_EQ(s, a.substr(0, i));
+  }
+
+  StringPiece s(a);
+  T v;
+  CHECK(OCRead(&s, &v));
+  CHECK(s.empty());
+  return v;
+}
+
+template <typename T>
+static void TestWriteRead(T expected) {
+  EXPECT_EQ(expected, TestRead<T>(OCWrite(expected)));
+}
+
+// Verifies that the second Write* call appends a non-empty string to its
+// output.
+template <typename T, typename U>
+static void TestWriteAppends(T first, U second) {
+  string encoded;
+  OCWriteToString(&encoded, first);
+  string encoded_first_only = encoded;
+  OCWriteToString(&encoded, second);
+  EXPECT_NE(encoded, encoded_first_only);
+  EXPECT_TRUE(StringPiece(encoded).starts_with(encoded_first_only));
+}
+
+template <typename T>
+static void TestNumbers(T multiplier) {
+  // first test powers of 2 (and nearby numbers)
+  for (T x = std::numeric_limits<T>().max(); x != 0; x /= 2) {
+    TestWriteRead(multiplier * (x - 1));
+    TestWriteRead(multiplier * x);
+    if (x != std::numeric_limits<T>::max()) {
+      TestWriteRead(multiplier * (x + 1));
+    } else if (multiplier < 0 && multiplier == -1) {
+      TestWriteRead(-x - 1);
+    }
+  }
+
+  random::PhiloxRandom philox(301, 17);
+  random::SimplePhilox rnd(&philox);
+  for (int bits = 1; bits <= std::numeric_limits<T>().digits; ++bits) {
+    // test random non-negative numbers with given number of significant bits
+    const uint64 mask = (~0ULL) >> (64 - bits);
+    for (int i = 0; i < 1000; i++) {
+      T x = rnd.Rand64() & mask;
+      TestWriteRead(multiplier * x);
+      T y = rnd.Rand64() & mask;
+      TestWriteAppends(multiplier * x, multiplier * y);
+    }
+  }
+}
+
+// Return true iff 'a' is "before" 'b'
+static bool CompareStrings(const string& a, const string& b) { return (a < b); }
+
+template <typename T>
+static void TestNumberOrdering() {
+  // first the negative numbers (if T is signed, otherwise no-op)
+  string laststr = OCWrite(std::numeric_limits<T>().min());
+  for (T num = std::numeric_limits<T>().min() / 2; num != 0; num /= 2) {
+    string strminus1 = OCWrite(num - 1);
+    string str = OCWrite(num);
+    string strplus1 = OCWrite(num + 1);
+
+    CHECK(CompareStrings(strminus1, str));
+    CHECK(CompareStrings(str, strplus1));
+
+    // Compare 'str' with 'laststr'.  When we approach 0, 'laststr' is
+    // not necessarily before 'strminus1'.
+    CHECK(CompareStrings(laststr, str));
+    laststr = str;
+  }
+
+  // then the positive numbers
+  laststr = OCWrite<T>(0);
+  T num = 1;
+  while (num < std::numeric_limits<T>().max() / 2) {
+    num *= 2;
+    string strminus1 = OCWrite(num - 1);
+    string str = OCWrite(num);
+    string strplus1 = OCWrite(num + 1);
+
+    CHECK(CompareStrings(strminus1, str));
+    CHECK(CompareStrings(str, strplus1));
+
+    // Compare 'str' with 'laststr'.
+    CHECK(CompareStrings(laststr, str));
+    laststr = str;
+  }
+}
+
+// Helper routine for testing TEST_SkipToNextSpecialByte
+static int FindSpecial(const string& x) {
+  const char* p = x.data();
+  const char* limit = p + x.size();
+  const char* result = OrderedCode::TEST_SkipToNextSpecialByte(p, limit);
+  return result - p;
+}
+
+TEST(OrderedCode, SkipToNextSpecialByte) {
+  for (size_t len = 0; len < 256; len++) {
+    random::PhiloxRandom philox(301, 17);
+    random::SimplePhilox rnd(&philox);
+    string x;
+    while (x.size() < len) {
+      char c = 1 + rnd.Uniform(254);
+      ASSERT_NE(c, 0);
+      ASSERT_NE(c, 255);
+      x += c;  // No 0 bytes, no 255 bytes
+    }
+    EXPECT_EQ(FindSpecial(x), x.size());
+    for (size_t special_pos = 0; special_pos < len; special_pos++) {
+      for (size_t special_test = 0; special_test < 2; special_test++) {
+        const char special_byte = (special_test == 0) ? 0 : 255;
+        string y = x;
+        y[special_pos] = special_byte;
+        EXPECT_EQ(FindSpecial(y), special_pos);
+        if (special_pos < 16) {
+          // Add some special bytes after the one at special_pos to make sure
+          // we still return the earliest special byte in the string
+          for (size_t rest = special_pos + 1; rest < len; rest++) {
+            if (rnd.OneIn(3)) {
+              y[rest] = rnd.OneIn(2) ? 0 : 255;
+              EXPECT_EQ(FindSpecial(y), special_pos);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(OrderedCode, ExhaustiveFindSpecial) {
+  char buf[16];
+  char* limit = buf + sizeof(buf);
+  int count = 0;
+  for (int start_offset = 0; start_offset <= 5; start_offset += 5) {
+    // We test exhaustively with all combinations of 3 bytes starting
+    // at offset 0 and offset 5 (so as to test with the bytes at both
+    // ends of a 64-bit word).
+    for (size_t i = 0; i < sizeof(buf); i++) {
+      buf[i] = 'a';  // Not a special byte
+    }
+    for (int b0 = 0; b0 < 256; b0++) {
+      for (int b1 = 0; b1 < 256; b1++) {
+        for (int b2 = 0; b2 < 256; b2++) {
+          buf[start_offset + 0] = b0;
+          buf[start_offset + 1] = b1;
+          buf[start_offset + 2] = b2;
+          char* expected;
+          if (b0 == 0 || b0 == 255) {
+            expected = &buf[start_offset];
+          } else if (b1 == 0 || b1 == 255) {
+            expected = &buf[start_offset + 1];
+          } else if (b2 == 0 || b2 == 255) {
+            expected = &buf[start_offset + 2];
+          } else {
+            expected = limit;
+          }
+          count++;
+          EXPECT_EQ(expected,
+                    OrderedCode::TEST_SkipToNextSpecialByte(buf, limit));
+        }
+      }
+    }
+  }
+  EXPECT_EQ(count, 256 * 256 * 256 * 2);
+}
+
+TEST(Uint64, EncodeDecode) { TestNumbers<uint64>(1); }
+
+TEST(Uint64, Ordering) { TestNumberOrdering<uint64>(); }
+
+TEST(Int64, EncodeDecode) {
+  TestNumbers<int64>(1);
+  TestNumbers<int64>(-1);
+}
+
+TEST(Int64, Ordering) { TestNumberOrdering<int64>(); }
+
+// Returns the bitwise complement of s.
+static inline string StrNot(const string& s) {
+  string result;
+  for (string::const_iterator it = s.begin(); it != s.end(); ++it)
+    result.push_back(~*it);
+  return result;
+}
+
+template <typename T>
+static void TestInvalidEncoding(const string& s) {
+  StringPiece p(s);
+  EXPECT_FALSE(OCRead<T>(&p, static_cast<T*>(NULL)));
+  EXPECT_EQ(s, p);
+}
+
+TEST(OrderedCodeInvalidEncodingsTest, Overflow) {
+  // 1U << 64, increasing and decreasing
+  const string k2xx64U = "\x09\x01" + string(8, 0);
+  TestInvalidEncoding<uint64>(k2xx64U);
+
+  // 1 << 63 and ~(1 << 63), increasing and decreasing
+  const string k2xx63 = "\xff\xc0\x80" + string(7, 0);
+  TestInvalidEncoding<int64>(k2xx63);
+  TestInvalidEncoding<int64>(StrNot(k2xx63));
+}
+
+TEST(OrderedCodeInvalidEncodingsDeathTest, NonCanonical) {
+  // Test "ambiguous"/"non-canonical" encodings.
+  // These are non-minimal (but otherwise "valid") encodings that
+  // differ from the minimal encoding chosen by OrderedCode::WriteXXX
+  // and thus should be avoided to not mess up the string ordering of
+  // encodings.
+
+  random::PhiloxRandom philox(301, 17);
+  random::SimplePhilox rnd(&philox);
+
+  for (int n = 2; n <= 9; ++n) {
+    // The zero in non_minimal[1] is "redundant".
+    string non_minimal =
+        string(1, n - 1) + string(1, 0) + RandomString(&rnd, n - 2);
+    EXPECT_EQ(n, non_minimal.length());
+
+    EXPECT_NE(OCWrite<uint64>(0), non_minimal);
+#ifndef NDEBUG
+    StringPiece s(non_minimal);
+    EXPECT_DEATH(OrderedCode::ReadNumIncreasing(&s, NULL), "invalid encoding");
+#else
+    TestRead<uint64>(non_minimal);
+#endif
+  }
+
+  for (int n = 2; n <= 10; ++n) {
+    // Header with 1 sign bit and n-1 size bits.
+    string header = string(n / 8, 0xff) + string(1, 0xff << (8 - (n % 8)));
+    // There are more than 7 zero bits between header bits and "payload".
+ string non_minimal = header + + string(1, rnd.Uniform(256) & ~*header.rbegin()) + + RandomString(&rnd, n - header.length() - 1); + EXPECT_EQ(n, non_minimal.length()); + + EXPECT_NE(OCWrite(0), non_minimal); +#ifndef NDEBUG + StringPiece s(non_minimal); + EXPECT_DEATH(OrderedCode::ReadSignedNumIncreasing(&s, NULL), + "invalid encoding") + << n; +#else + TestRead(non_minimal); +#endif + } +} + +// Returns random number with specified number of bits, +// i.e., in the range [2^(bits-1),2^bits). +static uint64 NextBits(random::SimplePhilox* rnd, int bits) { + return (bits != 0) + ? (rnd->Rand64() % (1LL << (bits - 1))) + (1LL << (bits - 1)) + : 0; +} + +template +static void BM_WriteNum(int n, T multiplier) { + static const int kValues = 64; + T values[kValues]; + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + // Use enough distinct values to confuse the branch predictor + for (int i = 0; i < kValues; i++) { + values[i] = NextBits(&rnd, n % 64) * multiplier; + } + string result; + int index = 0; + while (n-- > 0) { + result.clear(); + OCWriteToString(&result, values[index % kValues]); + index++; + } +} + +template +static void BM_ReadNum(int n, T multiplier) { + string x; + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + // Use enough distinct values to confuse the branch predictor + static const int kValues = 64; + string values[kValues]; + for (int i = 0; i < kValues; i++) { + T val = NextBits(&rnd, i % 64) * multiplier; + values[i] = OCWrite(val); + } + uint32 index = 0; + while (n-- > 0) { + T val; + StringPiece s = values[index++ % kValues]; + OCRead(&s, &val); + } +} + +#define BENCHMARK_NUM(name, T, multiplier) \ + static void BM_Write##name(int n) { BM_WriteNum(n, multiplier); } \ + BENCHMARK(BM_Write##name); \ + static void BM_Read##name(int n) { BM_ReadNum(n, multiplier); } \ + BENCHMARK(BM_Read##name) + +BENCHMARK_NUM(NumIncreasing, uint64, 1); +BENCHMARK_NUM(SignedNum, int64, 1); +BENCHMARK_NUM(SignedNumNegative, int64, -1); + +#undef BENCHMARK_NUM + +// --------------------------------------------------------------------- +// Strings + +TEST(String, EncodeDecode) { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + + for (int len = 0; len < 256; len++) { + const string a = RandomString(&rnd, len); + TestWriteRead(a); + for (int len2 = 0; len2 < 64; len2++) { + const string b = RandomString(&rnd, len2); + + TestWriteAppends(a, b); + + string out; + OCWriteToString(&out, a); + OCWriteToString(&out, b); + + string a2, b2, dummy; + StringPiece s = out; + StringPiece s2 = out; + CHECK(OCRead(&s, &a2)); + CHECK(OCRead(&s2, NULL)); + CHECK_EQ(s, s2); + + CHECK(OCRead(&s, &b2)); + CHECK(OCRead(&s2, NULL)); + CHECK_EQ(s, s2); + + CHECK(!OCRead(&s, &dummy)); + CHECK(!OCRead(&s2, NULL)); + CHECK_EQ(a, a2); + CHECK_EQ(b, b2); + CHECK(s.empty()); + CHECK(s2.empty()); + } + } +} + +// 'str' is a static C-style string that may contain '\0' +#define STATIC_STR(str) StringPiece((str), sizeof(str) - 1) + +static string EncodeStringIncreasing(StringPiece value) { + string encoded; + OrderedCode::WriteString(&encoded, value); + return encoded; +} + +TEST(String, Increasing) { + // Here are a series of strings in non-decreasing order, including + // consecutive strings such that the second one is equal to, a proper + // prefix of, or has the same length as the first one. Most also contain + // the special escaping characters '\x00' and '\xff'. 
+ ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("")), + EncodeStringIncreasing(STATIC_STR(""))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("")), + EncodeStringIncreasing(STATIC_STR("\x00"))); + + ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("\x00")), + EncodeStringIncreasing(STATIC_STR("\x00"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\x00")), + EncodeStringIncreasing(STATIC_STR("\x01"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\x01")), + EncodeStringIncreasing(STATIC_STR("a"))); + + ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("a")), + EncodeStringIncreasing(STATIC_STR("a"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("a")), + EncodeStringIncreasing(STATIC_STR("aa"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("aa")), + EncodeStringIncreasing(STATIC_STR("\xff"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\xff")), + EncodeStringIncreasing(STATIC_STR("\xff\x00"))); + + ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\xff\x00")), + EncodeStringIncreasing(STATIC_STR("\xff\x01"))); +} + +TEST(EncodingIsExpected, String) { + std::vector> data = { + {"", string("\x00\x01", 2)}, + {"foo", string("foo\x00\x01", 5)}, + {"hello", string("hello\x00\x01", 7)}, + {string("\x00\x01\xff", 3), string("\x00\xff\x01\xff\x00\x00\x01", 7)}, + }; + for (const auto& t : data) { + string result; + OrderedCode::WriteString(&result, t.first); + EXPECT_EQ(t.second, result); + + StringPiece in = result; + string decoded; + EXPECT_TRUE(OrderedCode::ReadString(&in, &decoded)); + EXPECT_EQ(t.first, decoded); + EXPECT_EQ("", in); + } +} + +TEST(EncodingIsExpected, Unsigned) { + std::vector> data = { + {0x0ull, string("\000", 1)}, + {0x1ull, string("\001\001", 2)}, + {0x2ull, string("\001\002", 2)}, + {0x1ull, string("\001\001", 2)}, + {0x2ull, string("\001\002", 2)}, + {0x3ull, string("\001\003", 2)}, + {0x3ull, string("\001\003", 2)}, + {0x4ull, string("\001\004", 2)}, + {0x5ull, string("\001\005", 2)}, + {0x7ull, string("\001\007", 2)}, + {0x8ull, string("\001\010", 2)}, + {0x9ull, string("\001\t", 2)}, + {0xfull, string("\001\017", 2)}, + {0x10ull, string("\001\020", 2)}, + {0x11ull, string("\001\021", 2)}, + {0x1full, string("\001\037", 2)}, + {0x20ull, string("\001 ", 2)}, + {0x21ull, string("\001!", 2)}, + {0x3full, string("\001?", 2)}, + {0x40ull, string("\001@", 2)}, + {0x41ull, string("\001A", 2)}, + {0x7full, string("\001\177", 2)}, + {0x80ull, string("\001\200", 2)}, + {0x81ull, string("\001\201", 2)}, + {0xffull, string("\001\377", 2)}, + {0x100ull, string("\002\001\000", 3)}, + {0x101ull, string("\002\001\001", 3)}, + {0x1ffull, string("\002\001\377", 3)}, + {0x200ull, string("\002\002\000", 3)}, + {0x201ull, string("\002\002\001", 3)}, + {0x3ffull, string("\002\003\377", 3)}, + {0x400ull, string("\002\004\000", 3)}, + {0x401ull, string("\002\004\001", 3)}, + {0x7ffull, string("\002\007\377", 3)}, + {0x800ull, string("\002\010\000", 3)}, + {0x801ull, string("\002\010\001", 3)}, + {0xfffull, string("\002\017\377", 3)}, + {0x1000ull, string("\002\020\000", 3)}, + {0x1001ull, string("\002\020\001", 3)}, + {0x1fffull, string("\002\037\377", 3)}, + {0x2000ull, string("\002 \000", 3)}, + {0x2001ull, string("\002 \001", 3)}, + {0x3fffull, string("\002?\377", 3)}, + {0x4000ull, string("\002@\000", 3)}, + {0x4001ull, string("\002@\001", 3)}, + {0x7fffull, string("\002\177\377", 3)}, + {0x8000ull, string("\002\200\000", 3)}, + {0x8001ull, string("\002\200\001", 3)}, + {0xffffull, string("\002\377\377", 3)}, + {0x10000ull, string("\003\001\000\000", 4)}, + {0x10001ull, 
string("\003\001\000\001", 4)}, + {0x1ffffull, string("\003\001\377\377", 4)}, + {0x20000ull, string("\003\002\000\000", 4)}, + {0x20001ull, string("\003\002\000\001", 4)}, + {0x3ffffull, string("\003\003\377\377", 4)}, + {0x40000ull, string("\003\004\000\000", 4)}, + {0x40001ull, string("\003\004\000\001", 4)}, + {0x7ffffull, string("\003\007\377\377", 4)}, + {0x80000ull, string("\003\010\000\000", 4)}, + {0x80001ull, string("\003\010\000\001", 4)}, + {0xfffffull, string("\003\017\377\377", 4)}, + {0x100000ull, string("\003\020\000\000", 4)}, + {0x100001ull, string("\003\020\000\001", 4)}, + {0x1fffffull, string("\003\037\377\377", 4)}, + {0x200000ull, string("\003 \000\000", 4)}, + {0x200001ull, string("\003 \000\001", 4)}, + {0x3fffffull, string("\003?\377\377", 4)}, + {0x400000ull, string("\003@\000\000", 4)}, + {0x400001ull, string("\003@\000\001", 4)}, + {0x7fffffull, string("\003\177\377\377", 4)}, + {0x800000ull, string("\003\200\000\000", 4)}, + {0x800001ull, string("\003\200\000\001", 4)}, + {0xffffffull, string("\003\377\377\377", 4)}, + {0x1000000ull, string("\004\001\000\000\000", 5)}, + {0x1000001ull, string("\004\001\000\000\001", 5)}, + {0x1ffffffull, string("\004\001\377\377\377", 5)}, + {0x2000000ull, string("\004\002\000\000\000", 5)}, + {0x2000001ull, string("\004\002\000\000\001", 5)}, + {0x3ffffffull, string("\004\003\377\377\377", 5)}, + {0x4000000ull, string("\004\004\000\000\000", 5)}, + {0x4000001ull, string("\004\004\000\000\001", 5)}, + {0x7ffffffull, string("\004\007\377\377\377", 5)}, + {0x8000000ull, string("\004\010\000\000\000", 5)}, + {0x8000001ull, string("\004\010\000\000\001", 5)}, + {0xfffffffull, string("\004\017\377\377\377", 5)}, + {0x10000000ull, string("\004\020\000\000\000", 5)}, + {0x10000001ull, string("\004\020\000\000\001", 5)}, + {0x1fffffffull, string("\004\037\377\377\377", 5)}, + {0x20000000ull, string("\004 \000\000\000", 5)}, + {0x20000001ull, string("\004 \000\000\001", 5)}, + {0x3fffffffull, string("\004?\377\377\377", 5)}, + {0x40000000ull, string("\004@\000\000\000", 5)}, + {0x40000001ull, string("\004@\000\000\001", 5)}, + {0x7fffffffull, string("\004\177\377\377\377", 5)}, + {0x80000000ull, string("\004\200\000\000\000", 5)}, + {0x80000001ull, string("\004\200\000\000\001", 5)}, + {0xffffffffull, string("\004\377\377\377\377", 5)}, + {0x100000000ull, string("\005\001\000\000\000\000", 6)}, + {0x100000001ull, string("\005\001\000\000\000\001", 6)}, + {0x1ffffffffull, string("\005\001\377\377\377\377", 6)}, + {0x200000000ull, string("\005\002\000\000\000\000", 6)}, + {0x200000001ull, string("\005\002\000\000\000\001", 6)}, + {0x3ffffffffull, string("\005\003\377\377\377\377", 6)}, + {0x400000000ull, string("\005\004\000\000\000\000", 6)}, + {0x400000001ull, string("\005\004\000\000\000\001", 6)}, + {0x7ffffffffull, string("\005\007\377\377\377\377", 6)}, + {0x800000000ull, string("\005\010\000\000\000\000", 6)}, + {0x800000001ull, string("\005\010\000\000\000\001", 6)}, + {0xfffffffffull, string("\005\017\377\377\377\377", 6)}, + {0x1000000000ull, string("\005\020\000\000\000\000", 6)}, + {0x1000000001ull, string("\005\020\000\000\000\001", 6)}, + {0x1fffffffffull, string("\005\037\377\377\377\377", 6)}, + {0x2000000000ull, string("\005 \000\000\000\000", 6)}, + {0x2000000001ull, string("\005 \000\000\000\001", 6)}, + {0x3fffffffffull, string("\005?\377\377\377\377", 6)}, + {0x4000000000ull, string("\005@\000\000\000\000", 6)}, + {0x4000000001ull, string("\005@\000\000\000\001", 6)}, + {0x7fffffffffull, 
string("\005\177\377\377\377\377", 6)}, + {0x8000000000ull, string("\005\200\000\000\000\000", 6)}, + {0x8000000001ull, string("\005\200\000\000\000\001", 6)}, + {0xffffffffffull, string("\005\377\377\377\377\377", 6)}, + {0x10000000000ull, string("\006\001\000\000\000\000\000", 7)}, + {0x10000000001ull, string("\006\001\000\000\000\000\001", 7)}, + {0x1ffffffffffull, string("\006\001\377\377\377\377\377", 7)}, + {0x20000000000ull, string("\006\002\000\000\000\000\000", 7)}, + {0x20000000001ull, string("\006\002\000\000\000\000\001", 7)}, + {0x3ffffffffffull, string("\006\003\377\377\377\377\377", 7)}, + {0x40000000000ull, string("\006\004\000\000\000\000\000", 7)}, + {0x40000000001ull, string("\006\004\000\000\000\000\001", 7)}, + {0x7ffffffffffull, string("\006\007\377\377\377\377\377", 7)}, + {0x80000000000ull, string("\006\010\000\000\000\000\000", 7)}, + {0x80000000001ull, string("\006\010\000\000\000\000\001", 7)}, + {0xfffffffffffull, string("\006\017\377\377\377\377\377", 7)}, + {0x100000000000ull, string("\006\020\000\000\000\000\000", 7)}, + {0x100000000001ull, string("\006\020\000\000\000\000\001", 7)}, + {0x1fffffffffffull, string("\006\037\377\377\377\377\377", 7)}, + {0x200000000000ull, string("\006 \000\000\000\000\000", 7)}, + {0x200000000001ull, string("\006 \000\000\000\000\001", 7)}, + {0x3fffffffffffull, string("\006?\377\377\377\377\377", 7)}, + {0x400000000000ull, string("\006@\000\000\000\000\000", 7)}, + {0x400000000001ull, string("\006@\000\000\000\000\001", 7)}, + {0x7fffffffffffull, string("\006\177\377\377\377\377\377", 7)}, + {0x800000000000ull, string("\006\200\000\000\000\000\000", 7)}, + {0x800000000001ull, string("\006\200\000\000\000\000\001", 7)}, + {0xffffffffffffull, string("\006\377\377\377\377\377\377", 7)}, + {0x1000000000000ull, string("\007\001\000\000\000\000\000\000", 8)}, + {0x1000000000001ull, string("\007\001\000\000\000\000\000\001", 8)}, + {0x1ffffffffffffull, string("\007\001\377\377\377\377\377\377", 8)}, + {0x2000000000000ull, string("\007\002\000\000\000\000\000\000", 8)}, + {0x2000000000001ull, string("\007\002\000\000\000\000\000\001", 8)}, + {0x3ffffffffffffull, string("\007\003\377\377\377\377\377\377", 8)}, + {0x4000000000000ull, string("\007\004\000\000\000\000\000\000", 8)}, + {0x4000000000001ull, string("\007\004\000\000\000\000\000\001", 8)}, + {0x7ffffffffffffull, string("\007\007\377\377\377\377\377\377", 8)}, + {0x8000000000000ull, string("\007\010\000\000\000\000\000\000", 8)}, + {0x8000000000001ull, string("\007\010\000\000\000\000\000\001", 8)}, + {0xfffffffffffffull, string("\007\017\377\377\377\377\377\377", 8)}, + {0x10000000000000ull, string("\007\020\000\000\000\000\000\000", 8)}, + {0x10000000000001ull, string("\007\020\000\000\000\000\000\001", 8)}, + {0x1fffffffffffffull, string("\007\037\377\377\377\377\377\377", 8)}, + {0x20000000000000ull, string("\007 \000\000\000\000\000\000", 8)}, + {0x20000000000001ull, string("\007 \000\000\000\000\000\001", 8)}, + {0x3fffffffffffffull, string("\007?\377\377\377\377\377\377", 8)}, + {0x40000000000000ull, string("\007@\000\000\000\000\000\000", 8)}, + {0x40000000000001ull, string("\007@\000\000\000\000\000\001", 8)}, + {0x7fffffffffffffull, string("\007\177\377\377\377\377\377\377", 8)}, + {0x80000000000000ull, string("\007\200\000\000\000\000\000\000", 8)}, + {0x80000000000001ull, string("\007\200\000\000\000\000\000\001", 8)}, + {0xffffffffffffffull, string("\007\377\377\377\377\377\377\377", 8)}, + {0x100000000000000ull, string("\010\001\000\000\000\000\000\000\000", 9)}, 
+ {0x100000000000001ull, string("\010\001\000\000\000\000\000\000\001", 9)}, + {0x1ffffffffffffffull, string("\010\001\377\377\377\377\377\377\377", 9)}, + {0x200000000000000ull, string("\010\002\000\000\000\000\000\000\000", 9)}, + {0x200000000000001ull, string("\010\002\000\000\000\000\000\000\001", 9)}, + {0x3ffffffffffffffull, string("\010\003\377\377\377\377\377\377\377", 9)}, + {0x400000000000000ull, string("\010\004\000\000\000\000\000\000\000", 9)}, + {0x400000000000001ull, string("\010\004\000\000\000\000\000\000\001", 9)}, + {0x7ffffffffffffffull, string("\010\007\377\377\377\377\377\377\377", 9)}, + {0x800000000000000ull, string("\010\010\000\000\000\000\000\000\000", 9)}, + {0x800000000000001ull, string("\010\010\000\000\000\000\000\000\001", 9)}, + {0xfffffffffffffffull, string("\010\017\377\377\377\377\377\377\377", 9)}, + {0x1000000000000000ull, + string("\010\020\000\000\000\000\000\000\000", 9)}, + {0x1000000000000001ull, + string("\010\020\000\000\000\000\000\000\001", 9)}, + {0x1fffffffffffffffull, + string("\010\037\377\377\377\377\377\377\377", 9)}, + {0x2000000000000000ull, string("\010 \000\000\000\000\000\000\000", 9)}, + {0x2000000000000001ull, string("\010 \000\000\000\000\000\000\001", 9)}, + {0x3fffffffffffffffull, string("\010?\377\377\377\377\377\377\377", 9)}, + {0x4000000000000000ull, string("\010@\000\000\000\000\000\000\000", 9)}, + {0x4000000000000001ull, string("\010@\000\000\000\000\000\000\001", 9)}, + {0x7fffffffffffffffull, + string("\010\177\377\377\377\377\377\377\377", 9)}, + {0x8000000000000000ull, + string("\010\200\000\000\000\000\000\000\000", 9)}, + {0x8000000000000001ull, + string("\010\200\000\000\000\000\000\000\001", 9)}, + }; + for (const auto& t : data) { + uint64 num = t.first; + string result; + OrderedCode::WriteNumIncreasing(&result, num); + EXPECT_EQ(t.second, result) << std::hex << num; + + StringPiece in = result; + uint64 decoded; + EXPECT_TRUE(OrderedCode::ReadNumIncreasing(&in, &decoded)); + EXPECT_EQ(num, decoded); + EXPECT_EQ("", in); + } +} + +TEST(EncodingIsExpected, Signed) { + std::vector> data = { + {0ll, string("\200", 1)}, + {1ll, string("\201", 1)}, + {2ll, string("\202", 1)}, + {1ll, string("\201", 1)}, + {2ll, string("\202", 1)}, + {3ll, string("\203", 1)}, + {3ll, string("\203", 1)}, + {4ll, string("\204", 1)}, + {5ll, string("\205", 1)}, + {7ll, string("\207", 1)}, + {8ll, string("\210", 1)}, + {9ll, string("\211", 1)}, + {15ll, string("\217", 1)}, + {16ll, string("\220", 1)}, + {17ll, string("\221", 1)}, + {31ll, string("\237", 1)}, + {32ll, string("\240", 1)}, + {33ll, string("\241", 1)}, + {63ll, string("\277", 1)}, + {64ll, string("\300@", 2)}, + {65ll, string("\300A", 2)}, + {127ll, string("\300\177", 2)}, + {128ll, string("\300\200", 2)}, + {129ll, string("\300\201", 2)}, + {255ll, string("\300\377", 2)}, + {256ll, string("\301\000", 2)}, + {257ll, string("\301\001", 2)}, + {511ll, string("\301\377", 2)}, + {512ll, string("\302\000", 2)}, + {513ll, string("\302\001", 2)}, + {1023ll, string("\303\377", 2)}, + {1024ll, string("\304\000", 2)}, + {1025ll, string("\304\001", 2)}, + {2047ll, string("\307\377", 2)}, + {2048ll, string("\310\000", 2)}, + {2049ll, string("\310\001", 2)}, + {4095ll, string("\317\377", 2)}, + {4096ll, string("\320\000", 2)}, + {4097ll, string("\320\001", 2)}, + {8191ll, string("\337\377", 2)}, + {8192ll, string("\340 \000", 3)}, + {8193ll, string("\340 \001", 3)}, + {16383ll, string("\340?\377", 3)}, + {16384ll, string("\340@\000", 3)}, + {16385ll, string("\340@\001", 3)}, + {32767ll, 
string("\340\177\377", 3)}, + {32768ll, string("\340\200\000", 3)}, + {32769ll, string("\340\200\001", 3)}, + {65535ll, string("\340\377\377", 3)}, + {65536ll, string("\341\000\000", 3)}, + {65537ll, string("\341\000\001", 3)}, + {131071ll, string("\341\377\377", 3)}, + {131072ll, string("\342\000\000", 3)}, + {131073ll, string("\342\000\001", 3)}, + {262143ll, string("\343\377\377", 3)}, + {262144ll, string("\344\000\000", 3)}, + {262145ll, string("\344\000\001", 3)}, + {524287ll, string("\347\377\377", 3)}, + {524288ll, string("\350\000\000", 3)}, + {524289ll, string("\350\000\001", 3)}, + {1048575ll, string("\357\377\377", 3)}, + {1048576ll, string("\360\020\000\000", 4)}, + {1048577ll, string("\360\020\000\001", 4)}, + {2097151ll, string("\360\037\377\377", 4)}, + {2097152ll, string("\360 \000\000", 4)}, + {2097153ll, string("\360 \000\001", 4)}, + {4194303ll, string("\360?\377\377", 4)}, + {4194304ll, string("\360@\000\000", 4)}, + {4194305ll, string("\360@\000\001", 4)}, + {8388607ll, string("\360\177\377\377", 4)}, + {8388608ll, string("\360\200\000\000", 4)}, + {8388609ll, string("\360\200\000\001", 4)}, + {16777215ll, string("\360\377\377\377", 4)}, + {16777216ll, string("\361\000\000\000", 4)}, + {16777217ll, string("\361\000\000\001", 4)}, + {33554431ll, string("\361\377\377\377", 4)}, + {33554432ll, string("\362\000\000\000", 4)}, + {33554433ll, string("\362\000\000\001", 4)}, + {67108863ll, string("\363\377\377\377", 4)}, + {67108864ll, string("\364\000\000\000", 4)}, + {67108865ll, string("\364\000\000\001", 4)}, + {134217727ll, string("\367\377\377\377", 4)}, + {134217728ll, string("\370\010\000\000\000", 5)}, + {134217729ll, string("\370\010\000\000\001", 5)}, + {268435455ll, string("\370\017\377\377\377", 5)}, + {268435456ll, string("\370\020\000\000\000", 5)}, + {268435457ll, string("\370\020\000\000\001", 5)}, + {536870911ll, string("\370\037\377\377\377", 5)}, + {536870912ll, string("\370 \000\000\000", 5)}, + {536870913ll, string("\370 \000\000\001", 5)}, + {1073741823ll, string("\370?\377\377\377", 5)}, + {1073741824ll, string("\370@\000\000\000", 5)}, + {1073741825ll, string("\370@\000\000\001", 5)}, + {2147483647ll, string("\370\177\377\377\377", 5)}, + {2147483648ll, string("\370\200\000\000\000", 5)}, + {2147483649ll, string("\370\200\000\000\001", 5)}, + {4294967295ll, string("\370\377\377\377\377", 5)}, + {4294967296ll, string("\371\000\000\000\000", 5)}, + {4294967297ll, string("\371\000\000\000\001", 5)}, + {8589934591ll, string("\371\377\377\377\377", 5)}, + {8589934592ll, string("\372\000\000\000\000", 5)}, + {8589934593ll, string("\372\000\000\000\001", 5)}, + {17179869183ll, string("\373\377\377\377\377", 5)}, + {17179869184ll, string("\374\004\000\000\000\000", 6)}, + {17179869185ll, string("\374\004\000\000\000\001", 6)}, + {34359738367ll, string("\374\007\377\377\377\377", 6)}, + {34359738368ll, string("\374\010\000\000\000\000", 6)}, + {34359738369ll, string("\374\010\000\000\000\001", 6)}, + {68719476735ll, string("\374\017\377\377\377\377", 6)}, + {68719476736ll, string("\374\020\000\000\000\000", 6)}, + {68719476737ll, string("\374\020\000\000\000\001", 6)}, + {137438953471ll, string("\374\037\377\377\377\377", 6)}, + {137438953472ll, string("\374 \000\000\000\000", 6)}, + {137438953473ll, string("\374 \000\000\000\001", 6)}, + {274877906943ll, string("\374?\377\377\377\377", 6)}, + {274877906944ll, string("\374@\000\000\000\000", 6)}, + {274877906945ll, string("\374@\000\000\000\001", 6)}, + {549755813887ll, string("\374\177\377\377\377\377", 6)}, 
+ {549755813888ll, string("\374\200\000\000\000\000", 6)}, + {549755813889ll, string("\374\200\000\000\000\001", 6)}, + {1099511627775ll, string("\374\377\377\377\377\377", 6)}, + {1099511627776ll, string("\375\000\000\000\000\000", 6)}, + {1099511627777ll, string("\375\000\000\000\000\001", 6)}, + {2199023255551ll, string("\375\377\377\377\377\377", 6)}, + {2199023255552ll, string("\376\002\000\000\000\000\000", 7)}, + {2199023255553ll, string("\376\002\000\000\000\000\001", 7)}, + {4398046511103ll, string("\376\003\377\377\377\377\377", 7)}, + {4398046511104ll, string("\376\004\000\000\000\000\000", 7)}, + {4398046511105ll, string("\376\004\000\000\000\000\001", 7)}, + {8796093022207ll, string("\376\007\377\377\377\377\377", 7)}, + {8796093022208ll, string("\376\010\000\000\000\000\000", 7)}, + {8796093022209ll, string("\376\010\000\000\000\000\001", 7)}, + {17592186044415ll, string("\376\017\377\377\377\377\377", 7)}, + {17592186044416ll, string("\376\020\000\000\000\000\000", 7)}, + {17592186044417ll, string("\376\020\000\000\000\000\001", 7)}, + {35184372088831ll, string("\376\037\377\377\377\377\377", 7)}, + {35184372088832ll, string("\376 \000\000\000\000\000", 7)}, + {35184372088833ll, string("\376 \000\000\000\000\001", 7)}, + {70368744177663ll, string("\376?\377\377\377\377\377", 7)}, + {70368744177664ll, string("\376@\000\000\000\000\000", 7)}, + {70368744177665ll, string("\376@\000\000\000\000\001", 7)}, + {140737488355327ll, string("\376\177\377\377\377\377\377", 7)}, + {140737488355328ll, string("\376\200\000\000\000\000\000", 7)}, + {140737488355329ll, string("\376\200\000\000\000\000\001", 7)}, + {281474976710655ll, string("\376\377\377\377\377\377\377", 7)}, + {281474976710656ll, string("\377\001\000\000\000\000\000\000", 8)}, + {281474976710657ll, string("\377\001\000\000\000\000\000\001", 8)}, + {562949953421311ll, string("\377\001\377\377\377\377\377\377", 8)}, + {562949953421312ll, string("\377\002\000\000\000\000\000\000", 8)}, + {562949953421313ll, string("\377\002\000\000\000\000\000\001", 8)}, + {1125899906842623ll, string("\377\003\377\377\377\377\377\377", 8)}, + {1125899906842624ll, string("\377\004\000\000\000\000\000\000", 8)}, + {1125899906842625ll, string("\377\004\000\000\000\000\000\001", 8)}, + {2251799813685247ll, string("\377\007\377\377\377\377\377\377", 8)}, + {2251799813685248ll, string("\377\010\000\000\000\000\000\000", 8)}, + {2251799813685249ll, string("\377\010\000\000\000\000\000\001", 8)}, + {4503599627370495ll, string("\377\017\377\377\377\377\377\377", 8)}, + {4503599627370496ll, string("\377\020\000\000\000\000\000\000", 8)}, + {4503599627370497ll, string("\377\020\000\000\000\000\000\001", 8)}, + {9007199254740991ll, string("\377\037\377\377\377\377\377\377", 8)}, + {9007199254740992ll, string("\377 \000\000\000\000\000\000", 8)}, + {9007199254740993ll, string("\377 \000\000\000\000\000\001", 8)}, + {18014398509481983ll, string("\377?\377\377\377\377\377\377", 8)}, + {18014398509481984ll, string("\377@\000\000\000\000\000\000", 8)}, + {18014398509481985ll, string("\377@\000\000\000\000\000\001", 8)}, + {36028797018963967ll, string("\377\177\377\377\377\377\377\377", 8)}, + {36028797018963968ll, string("\377\200\200\000\000\000\000\000\000", 9)}, + {36028797018963969ll, string("\377\200\200\000\000\000\000\000\001", 9)}, + {72057594037927935ll, string("\377\200\377\377\377\377\377\377\377", 9)}, + {72057594037927936ll, string("\377\201\000\000\000\000\000\000\000", 9)}, + {72057594037927937ll, string("\377\201\000\000\000\000\000\000\001", 
9)}, + {144115188075855871ll, string("\377\201\377\377\377\377\377\377\377", 9)}, + {144115188075855872ll, string("\377\202\000\000\000\000\000\000\000", 9)}, + {144115188075855873ll, string("\377\202\000\000\000\000\000\000\001", 9)}, + {288230376151711743ll, string("\377\203\377\377\377\377\377\377\377", 9)}, + {288230376151711744ll, string("\377\204\000\000\000\000\000\000\000", 9)}, + {288230376151711745ll, string("\377\204\000\000\000\000\000\000\001", 9)}, + {576460752303423487ll, string("\377\207\377\377\377\377\377\377\377", 9)}, + {576460752303423488ll, string("\377\210\000\000\000\000\000\000\000", 9)}, + {576460752303423489ll, string("\377\210\000\000\000\000\000\000\001", 9)}, + {1152921504606846975ll, + string("\377\217\377\377\377\377\377\377\377", 9)}, + {1152921504606846976ll, + string("\377\220\000\000\000\000\000\000\000", 9)}, + {1152921504606846977ll, + string("\377\220\000\000\000\000\000\000\001", 9)}, + {2305843009213693951ll, + string("\377\237\377\377\377\377\377\377\377", 9)}, + {2305843009213693952ll, + string("\377\240\000\000\000\000\000\000\000", 9)}, + {2305843009213693953ll, + string("\377\240\000\000\000\000\000\000\001", 9)}, + {4611686018427387903ll, + string("\377\277\377\377\377\377\377\377\377", 9)}, + {4611686018427387904ll, + string("\377\300@\000\000\000\000\000\000\000", 10)}, + {4611686018427387905ll, + string("\377\300@\000\000\000\000\000\000\001", 10)}, + {9223372036854775807ll, + string("\377\300\177\377\377\377\377\377\377\377", 10)}, + {-9223372036854775807ll, + string("\000?\200\000\000\000\000\000\000\001", 10)}, + {0ll, string("\200", 1)}, + {-1ll, string("\177", 1)}, + {-2ll, string("~", 1)}, + {-1ll, string("\177", 1)}, + {-2ll, string("~", 1)}, + {-3ll, string("}", 1)}, + {-3ll, string("}", 1)}, + {-4ll, string("|", 1)}, + {-5ll, string("{", 1)}, + {-7ll, string("y", 1)}, + {-8ll, string("x", 1)}, + {-9ll, string("w", 1)}, + {-15ll, string("q", 1)}, + {-16ll, string("p", 1)}, + {-17ll, string("o", 1)}, + {-31ll, string("a", 1)}, + {-32ll, string("`", 1)}, + {-33ll, string("_", 1)}, + {-63ll, string("A", 1)}, + {-64ll, string("@", 1)}, + {-65ll, string("?\277", 2)}, + {-127ll, string("?\201", 2)}, + {-128ll, string("?\200", 2)}, + {-129ll, string("?\177", 2)}, + {-255ll, string("?\001", 2)}, + {-256ll, string("?\000", 2)}, + {-257ll, string(">\377", 2)}, + {-511ll, string(">\001", 2)}, + {-512ll, string(">\000", 2)}, + {-513ll, string("=\377", 2)}, + {-1023ll, string("<\001", 2)}, + {-1024ll, string("<\000", 2)}, + {-1025ll, string(";\377", 2)}, + {-2047ll, string("8\001", 2)}, + {-2048ll, string("8\000", 2)}, + {-2049ll, string("7\377", 2)}, + {-4095ll, string("0\001", 2)}, + {-4096ll, string("0\000", 2)}, + {-4097ll, string("/\377", 2)}, + {-8191ll, string(" \001", 2)}, + {-8192ll, string(" \000", 2)}, + {-8193ll, string("\037\337\377", 3)}, + {-16383ll, string("\037\300\001", 3)}, + {-16384ll, string("\037\300\000", 3)}, + {-16385ll, string("\037\277\377", 3)}, + {-32767ll, string("\037\200\001", 3)}, + {-32768ll, string("\037\200\000", 3)}, + {-32769ll, string("\037\177\377", 3)}, + {-65535ll, string("\037\000\001", 3)}, + {-65536ll, string("\037\000\000", 3)}, + {-65537ll, string("\036\377\377", 3)}, + {-131071ll, string("\036\000\001", 3)}, + {-131072ll, string("\036\000\000", 3)}, + {-131073ll, string("\035\377\377", 3)}, + {-262143ll, string("\034\000\001", 3)}, + {-262144ll, string("\034\000\000", 3)}, + {-262145ll, string("\033\377\377", 3)}, + {-524287ll, string("\030\000\001", 3)}, + {-524288ll, string("\030\000\000", 3)}, + 
{-524289ll, string("\027\377\377", 3)}, + {-1048575ll, string("\020\000\001", 3)}, + {-1048576ll, string("\020\000\000", 3)}, + {-1048577ll, string("\017\357\377\377", 4)}, + {-2097151ll, string("\017\340\000\001", 4)}, + {-2097152ll, string("\017\340\000\000", 4)}, + {-2097153ll, string("\017\337\377\377", 4)}, + {-4194303ll, string("\017\300\000\001", 4)}, + {-4194304ll, string("\017\300\000\000", 4)}, + {-4194305ll, string("\017\277\377\377", 4)}, + {-8388607ll, string("\017\200\000\001", 4)}, + {-8388608ll, string("\017\200\000\000", 4)}, + {-8388609ll, string("\017\177\377\377", 4)}, + {-16777215ll, string("\017\000\000\001", 4)}, + {-16777216ll, string("\017\000\000\000", 4)}, + {-16777217ll, string("\016\377\377\377", 4)}, + {-33554431ll, string("\016\000\000\001", 4)}, + {-33554432ll, string("\016\000\000\000", 4)}, + {-33554433ll, string("\r\377\377\377", 4)}, + {-67108863ll, string("\014\000\000\001", 4)}, + {-67108864ll, string("\014\000\000\000", 4)}, + {-67108865ll, string("\013\377\377\377", 4)}, + {-134217727ll, string("\010\000\000\001", 4)}, + {-134217728ll, string("\010\000\000\000", 4)}, + {-134217729ll, string("\007\367\377\377\377", 5)}, + {-268435455ll, string("\007\360\000\000\001", 5)}, + {-268435456ll, string("\007\360\000\000\000", 5)}, + {-268435457ll, string("\007\357\377\377\377", 5)}, + {-536870911ll, string("\007\340\000\000\001", 5)}, + {-536870912ll, string("\007\340\000\000\000", 5)}, + {-536870913ll, string("\007\337\377\377\377", 5)}, + {-1073741823ll, string("\007\300\000\000\001", 5)}, + {-1073741824ll, string("\007\300\000\000\000", 5)}, + {-1073741825ll, string("\007\277\377\377\377", 5)}, + {-2147483647ll, string("\007\200\000\000\001", 5)}, + {-2147483648ll, string("\007\200\000\000\000", 5)}, + {-2147483649ll, string("\007\177\377\377\377", 5)}, + {-4294967295ll, string("\007\000\000\000\001", 5)}, + {-4294967296ll, string("\007\000\000\000\000", 5)}, + {-4294967297ll, string("\006\377\377\377\377", 5)}, + {-8589934591ll, string("\006\000\000\000\001", 5)}, + {-8589934592ll, string("\006\000\000\000\000", 5)}, + {-8589934593ll, string("\005\377\377\377\377", 5)}, + {-17179869183ll, string("\004\000\000\000\001", 5)}, + {-17179869184ll, string("\004\000\000\000\000", 5)}, + {-17179869185ll, string("\003\373\377\377\377\377", 6)}, + {-34359738367ll, string("\003\370\000\000\000\001", 6)}, + {-34359738368ll, string("\003\370\000\000\000\000", 6)}, + {-34359738369ll, string("\003\367\377\377\377\377", 6)}, + {-68719476735ll, string("\003\360\000\000\000\001", 6)}, + {-68719476736ll, string("\003\360\000\000\000\000", 6)}, + {-68719476737ll, string("\003\357\377\377\377\377", 6)}, + {-137438953471ll, string("\003\340\000\000\000\001", 6)}, + {-137438953472ll, string("\003\340\000\000\000\000", 6)}, + {-137438953473ll, string("\003\337\377\377\377\377", 6)}, + {-274877906943ll, string("\003\300\000\000\000\001", 6)}, + {-274877906944ll, string("\003\300\000\000\000\000", 6)}, + {-274877906945ll, string("\003\277\377\377\377\377", 6)}, + {-549755813887ll, string("\003\200\000\000\000\001", 6)}, + {-549755813888ll, string("\003\200\000\000\000\000", 6)}, + {-549755813889ll, string("\003\177\377\377\377\377", 6)}, + {-1099511627775ll, string("\003\000\000\000\000\001", 6)}, + {-1099511627776ll, string("\003\000\000\000\000\000", 6)}, + {-1099511627777ll, string("\002\377\377\377\377\377", 6)}, + {-2199023255551ll, string("\002\000\000\000\000\001", 6)}, + {-2199023255552ll, string("\002\000\000\000\000\000", 6)}, + {-2199023255553ll, 
string("\001\375\377\377\377\377\377", 7)}, + {-4398046511103ll, string("\001\374\000\000\000\000\001", 7)}, + {-4398046511104ll, string("\001\374\000\000\000\000\000", 7)}, + {-4398046511105ll, string("\001\373\377\377\377\377\377", 7)}, + {-8796093022207ll, string("\001\370\000\000\000\000\001", 7)}, + {-8796093022208ll, string("\001\370\000\000\000\000\000", 7)}, + {-8796093022209ll, string("\001\367\377\377\377\377\377", 7)}, + {-17592186044415ll, string("\001\360\000\000\000\000\001", 7)}, + {-17592186044416ll, string("\001\360\000\000\000\000\000", 7)}, + {-17592186044417ll, string("\001\357\377\377\377\377\377", 7)}, + {-35184372088831ll, string("\001\340\000\000\000\000\001", 7)}, + {-35184372088832ll, string("\001\340\000\000\000\000\000", 7)}, + {-35184372088833ll, string("\001\337\377\377\377\377\377", 7)}, + {-70368744177663ll, string("\001\300\000\000\000\000\001", 7)}, + {-70368744177664ll, string("\001\300\000\000\000\000\000", 7)}, + {-70368744177665ll, string("\001\277\377\377\377\377\377", 7)}, + {-140737488355327ll, string("\001\200\000\000\000\000\001", 7)}, + {-140737488355328ll, string("\001\200\000\000\000\000\000", 7)}, + {-140737488355329ll, string("\001\177\377\377\377\377\377", 7)}, + {-281474976710655ll, string("\001\000\000\000\000\000\001", 7)}, + {-281474976710656ll, string("\001\000\000\000\000\000\000", 7)}, + {-281474976710657ll, string("\000\376\377\377\377\377\377\377", 8)}, + {-562949953421311ll, string("\000\376\000\000\000\000\000\001", 8)}, + {-562949953421312ll, string("\000\376\000\000\000\000\000\000", 8)}, + {-562949953421313ll, string("\000\375\377\377\377\377\377\377", 8)}, + {-1125899906842623ll, string("\000\374\000\000\000\000\000\001", 8)}, + {-1125899906842624ll, string("\000\374\000\000\000\000\000\000", 8)}, + {-1125899906842625ll, string("\000\373\377\377\377\377\377\377", 8)}, + {-2251799813685247ll, string("\000\370\000\000\000\000\000\001", 8)}, + {-2251799813685248ll, string("\000\370\000\000\000\000\000\000", 8)}, + {-2251799813685249ll, string("\000\367\377\377\377\377\377\377", 8)}, + {-4503599627370495ll, string("\000\360\000\000\000\000\000\001", 8)}, + {-4503599627370496ll, string("\000\360\000\000\000\000\000\000", 8)}, + {-4503599627370497ll, string("\000\357\377\377\377\377\377\377", 8)}, + {-9007199254740991ll, string("\000\340\000\000\000\000\000\001", 8)}, + {-9007199254740992ll, string("\000\340\000\000\000\000\000\000", 8)}, + {-9007199254740993ll, string("\000\337\377\377\377\377\377\377", 8)}, + {-18014398509481983ll, string("\000\300\000\000\000\000\000\001", 8)}, + {-18014398509481984ll, string("\000\300\000\000\000\000\000\000", 8)}, + {-18014398509481985ll, string("\000\277\377\377\377\377\377\377", 8)}, + {-36028797018963967ll, string("\000\200\000\000\000\000\000\001", 8)}, + {-36028797018963968ll, string("\000\200\000\000\000\000\000\000", 8)}, + {-36028797018963969ll, string("\000\177\177\377\377\377\377\377\377", 9)}, + {-72057594037927935ll, string("\000\177\000\000\000\000\000\000\001", 9)}, + {-72057594037927936ll, string("\000\177\000\000\000\000\000\000\000", 9)}, + {-72057594037927937ll, string("\000~\377\377\377\377\377\377\377", 9)}, + {-144115188075855871ll, string("\000~\000\000\000\000\000\000\001", 9)}, + {-144115188075855872ll, string("\000~\000\000\000\000\000\000\000", 9)}, + {-144115188075855873ll, string("\000}\377\377\377\377\377\377\377", 9)}, + {-288230376151711743ll, string("\000|\000\000\000\000\000\000\001", 9)}, + {-288230376151711744ll, string("\000|\000\000\000\000\000\000\000", 
9)}, + {-288230376151711745ll, string("\000{\377\377\377\377\377\377\377", 9)}, + {-576460752303423487ll, string("\000x\000\000\000\000\000\000\001", 9)}, + {-576460752303423488ll, string("\000x\000\000\000\000\000\000\000", 9)}, + {-576460752303423489ll, string("\000w\377\377\377\377\377\377\377", 9)}, + {-1152921504606846975ll, string("\000p\000\000\000\000\000\000\001", 9)}, + {-1152921504606846976ll, string("\000p\000\000\000\000\000\000\000", 9)}, + {-1152921504606846977ll, string("\000o\377\377\377\377\377\377\377", 9)}, + {-2305843009213693951ll, string("\000`\000\000\000\000\000\000\001", 9)}, + {-2305843009213693952ll, string("\000`\000\000\000\000\000\000\000", 9)}, + {-2305843009213693953ll, string("\000_\377\377\377\377\377\377\377", 9)}, + {-4611686018427387903ll, string("\000@\000\000\000\000\000\000\001", 9)}, + {-4611686018427387904ll, string("\000@\000\000\000\000\000\000\000", 9)}, + {-4611686018427387905ll, + string("\000?\277\377\377\377\377\377\377\377", 10)}, + {-9223372036854775807ll, + string("\000?\200\000\000\000\000\000\000\001", 10)}, + {9223372036854775807ll, + string("\377\300\177\377\377\377\377\377\377\377", 10)}, + }; + for (const auto& t : data) { + int64 num = t.first; + string result; + OrderedCode::WriteSignedNumIncreasing(&result, num); + EXPECT_EQ(t.second, result) << std::hex << num; + + StringPiece in = result; + int64 decoded; + EXPECT_TRUE(OrderedCode::ReadSignedNumIncreasing(&in, &decoded)); + EXPECT_EQ(num, decoded); + EXPECT_EQ("", in); + } +} + +static void BM_WriteString(int n, int len) { + testing::StopTiming(); + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + string x; + for (int i = 0; i < len; i++) { + x += rnd.Uniform(256); + } + string y; + + testing::BytesProcessed(n * len); + testing::StartTiming(); + while (n-- > 0) { + y.clear(); + OCWriteToString(&y, x); + } +} + +static void BM_ReadString(int n, int len) { + testing::StopTiming(); + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + string x; + for (int i = 0; i < len; i++) { + x += rnd.Uniform(256); + } + string data; + OCWriteToString(&data, x); + string result; + + testing::BytesProcessed(n * len); + testing::StartTiming(); + while (n-- > 0) { + result.clear(); + StringPiece s = data; + OCRead(&s, &result); + } +} + +static void BM_WriteStringIncreasing(int n, int len) { BM_WriteString(n, len); } +static void BM_ReadStringIncreasing(int n, int len) { BM_ReadString(n, len); } + +BENCHMARK(BM_WriteStringIncreasing)->Range(0, 1024); +BENCHMARK(BM_ReadStringIncreasing)->Range(0, 1024); + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc new file mode 100644 index 0000000000..cccd50c7ff --- /dev/null +++ b/tensorflow/core/lib/strings/str_util.cc @@ -0,0 +1,312 @@ +#include "tensorflow/core/lib/strings/str_util.h" +#include + +namespace tensorflow { +namespace str_util { + +static char hex_char[] = "0123456789abcdef"; + +string CEscape(const string& src) { + string dest; + + for (unsigned char c : src) { + switch (c) { + case '\n': + dest.append("\\n"); + break; + case '\r': + dest.append("\\r"); + break; + case '\t': + dest.append("\\t"); + break; + case '\"': + dest.append("\\\""); + break; + case '\'': + dest.append("\\'"); + break; + case '\\': + dest.append("\\\\"); + break; + default: + // Note that if we emit \xNN and the src character after that is a hex + // digit then that digit must be escaped too to prevent it being + // 
interpreted as part of the character code by C. + if ((c >= 0x80) || !isprint(c)) { + dest.append("\\"); + dest.push_back(hex_char[c / 64]); + dest.push_back(hex_char[(c % 64) / 8]); + dest.push_back(hex_char[c % 8]); + } else { + dest.push_back(c); + break; + } + } + } + + return dest; +} + +namespace { // Private helpers for CUnescape(). + +inline bool is_octal_digit(unsigned char c) { return c >= '0' && c <= '7'; } + +inline bool ascii_isxdigit(unsigned char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || + (c >= 'A' && c <= 'F'); +} + +inline int hex_digit_to_int(char c) { + int x = static_cast(c); + if (x > '9') { + x += 9; + } + return x & 0xf; +} + +bool CUnescapeInternal(StringPiece source, char* dest, int* dest_len, + string* error) { + char* d = dest; + const char* p = source.data(); + const char* end = source.end(); + const char* last_byte = end - 1; + + // Small optimization for case where source = dest and there's no escaping + while (p == d && p < end && *p != '\\') p++, d++; + + while (p < end) { + if (*p != '\\') { + *d++ = *p++; + } else { + if (++p > last_byte) { // skip past the '\\' + if (error) *error = "String cannot end with \\"; + return false; + } + switch (*p) { + case 'a': + *d++ = '\a'; + break; + case 'b': + *d++ = '\b'; + break; + case 'f': + *d++ = '\f'; + break; + case 'n': + *d++ = '\n'; + break; + case 'r': + *d++ = '\r'; + break; + case 't': + *d++ = '\t'; + break; + case 'v': + *d++ = '\v'; + break; + case '\\': + *d++ = '\\'; + break; + case '?': + *d++ = '\?'; + break; // \? Who knew? + case '\'': + *d++ = '\''; + break; + case '"': + *d++ = '\"'; + break; + case '0': + case '1': + case '2': + case '3': // octal digit: 1 to 3 digits + case '4': + case '5': + case '6': + case '7': { + const char* octal_start = p; + unsigned int ch = *p - '0'; + if (p < last_byte && is_octal_digit(p[1])) ch = ch * 8 + *++p - '0'; + if (p < last_byte && is_octal_digit(p[1])) + ch = ch * 8 + *++p - '0'; // now points at last digit + if (ch > 0xff) { + if (error) { + *error = "Value of \\" + + string(octal_start, p + 1 - octal_start) + + " exceeds 0xff"; + } + return false; + } + *d++ = ch; + break; + } + case 'x': + case 'X': { + if (p >= last_byte) { + if (error) *error = "String cannot end with \\x"; + return false; + } else if (!ascii_isxdigit(p[1])) { + if (error) *error = "\\x cannot be followed by a non-hex digit"; + return false; + } + unsigned int ch = 0; + const char* hex_start = p; + while (p < last_byte && ascii_isxdigit(p[1])) + // Arbitrarily many hex digits + ch = (ch << 4) + hex_digit_to_int(*++p); + if (ch > 0xFF) { + if (error) { + *error = "Value of \\" + string(hex_start, p + 1 - hex_start) + + " exceeds 0xff"; + } + return false; + } + *d++ = ch; + break; + } + default: { + if (error) *error = string("Unknown escape sequence: \\") + *p; + return false; + } + } + p++; // read past letter we escaped + } + } + *dest_len = d - dest; + return true; +} + +} // namespace + +bool CUnescape(StringPiece source, string* dest, string* error) { + dest->resize(source.size()); + int dest_size; + if (!CUnescapeInternal(source, const_cast(dest->data()), &dest_size, + error)) { + return false; + } + dest->erase(dest_size); + return true; +} + +bool NumericParse32(const string& text, int32* val) { + // Slow, but this code is not performance critical, and this + // doesn't bring in any new dependencies + char junk; + if (sscanf(text.c_str(), "%d%c", val, &junk) == 1) { + return true; + } else { + return false; + } +} + +void 
StripTrailingWhitespace(string* s) { + string::size_type i; + for (i = s->size(); i > 0 && isspace((*s)[i - 1]); --i) { + } + s->resize(i); +} + +// Return lower-cased version of s. +string Lowercase(StringPiece s) { + string result(s.data(), s.size()); + for (char& c : result) { + c = tolower(c); + } + return result; +} + +// Return upper-cased version of s. +string Uppercase(StringPiece s) { + string result(s.data(), s.size()); + for (char& c : result) { + c = toupper(c); + } + return result; +} + +void TitlecaseString(string* s, StringPiece delimiters) { + bool upper = true; + for (string::iterator ss = s->begin(); ss != s->end(); ++ss) { + if (upper) { + *ss = toupper(*ss); + } + upper = (delimiters.find(*ss) != StringPiece::npos); + } +} + +size_t RemoveLeadingWhitespace(StringPiece* text) { + size_t count = 0; + const char* ptr = text->data(); + while (count < text->size() && isspace(*ptr)) { + count++; + ptr++; + } + text->remove_prefix(count); + return count; +} + +size_t RemoveTrailingWhitespace(StringPiece* text) { + size_t count = 0; + const char* ptr = text->data() + text->size() - 1; + while (count < text->size() && isspace(*ptr)) { + ++count; + --ptr; + } + text->remove_suffix(count); + return count; +} + +size_t RemoveWhitespaceContext(StringPiece* text) { + // use RemoveLeadingWhitespace() and RemoveTrailingWhitespace() to do the job + return (RemoveLeadingWhitespace(text) + RemoveTrailingWhitespace(text)); +} + +bool ConsumePrefix(StringPiece* s, StringPiece expected) { + if (s->starts_with(expected)) { + s->remove_prefix(expected.size()); + return true; + } + return false; +} + +bool ConsumeLeadingDigits(StringPiece* s, uint64* val) { + const char* p = s->data(); + const char* limit = p + s->size(); + uint64 v = 0; + while (p < limit) { + const char c = *p; + if (c < '0' || c > '9') break; + uint64 new_v = (v * 10) + (c - '0'); + if (new_v < v) { + // Overflow occurred + return false; + } + v = new_v; + p++; + } + if (p > s->data()) { + // Consume some digits + s->remove_prefix(p - s->data()); + *val = v; + return true; + } else { + return false; + } +} + +bool SplitAndParseAsInts(StringPiece text, char delim, + std::vector* result) { + result->clear(); + std::vector num_strings = Split(text, delim); + for (const auto& s : num_strings) { + int32 num; + if (!NumericParse32(s, &num)) return false; + result->push_back(num); + } + return true; +} + +} // namespace str_util +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h new file mode 100644 index 0000000000..34ea462b2d --- /dev/null +++ b/tensorflow/core/lib/strings/str_util.h @@ -0,0 +1,149 @@ +#ifndef TENSORFLOW_LIB_STRINGS_STR_UTIL_H_ +#define TENSORFLOW_LIB_STRINGS_STR_UTIL_H_ + +#include +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" + +// Basic string utility routines +namespace tensorflow { +namespace str_util { + +// Returns a version of 'src' where unprintable characters have been +// escaped using C-style escape sequences. +string CEscape(const string& src); + +// Copies "source" to "dest", rewriting C-style escape sequences -- +// '\n', '\r', '\\', '\ooo', etc -- to their ASCII equivalents. +// +// Errors: Sets the description of the first encountered error in +// 'error'. To disable error reporting, set 'error' to NULL. +// +// NOTE: Does not support \u or \U! 
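As a small illustration of the CEscape/CUnescape contract just described, here is a hedged sketch; the namespace `example` and the helper `RoundTrips` are hypothetical names introduced only for this illustration, not part of str_util. Escaping a byte string and then unescaping it recovers the original bytes.

// Illustrative sketch only: assumes str_util.h as declared here.
#include <string>

#include "tensorflow/core/lib/strings/str_util.h"

namespace example {

// Returns true if CEscape followed by CUnescape reproduces "raw".
// CUnescape does not support \u or \U escapes, but CEscape never
// emits them, so the round trip is well defined.
bool RoundTrips(const std::string& raw) {
  const std::string escaped = tensorflow::str_util::CEscape(raw);
  std::string unescaped;
  std::string error;
  return tensorflow::str_util::CUnescape(escaped, &unescaped, &error) &&
         unescaped == raw;
}

}  // namespace example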
+bool CUnescape(StringPiece source, string* dest, string* error);
+
+// If "text" can be successfully parsed as the ASCII representation of
+// an integer, sets "*val" to the value and returns true.  Otherwise,
+// returns false.
+bool NumericParse32(const string& text, int32* val);
+
+// Removes any trailing whitespace from "*s".
+void StripTrailingWhitespace(string* s);
+
+// Removes leading ascii_isspace() characters.
+// Returns number of characters removed.
+size_t RemoveLeadingWhitespace(StringPiece* text);
+
+// Removes trailing ascii_isspace() characters.
+// Returns number of characters removed.
+size_t RemoveTrailingWhitespace(StringPiece* text);
+
+// Removes leading and trailing ascii_isspace() chars.
+// Returns number of chars removed.
+size_t RemoveWhitespaceContext(StringPiece* text);
+
+// Consume a leading positive integer value.  If any digits were
+// found, store the value of the leading unsigned number in "*val",
+// advance "*s" past the consumed number, and return true.  Returns
+// false if no digits were found or the value would overflow a uint64;
+// in that case "*s" is left unchanged.
+bool ConsumeLeadingDigits(StringPiece* s, uint64* val);
+
+// If "*s" starts with "expected", consume it and return true.
+// Otherwise, return false.
+bool ConsumePrefix(StringPiece* s, StringPiece expected);
+
+// Return lower-cased version of s.
+string Lowercase(StringPiece s);
+
+// Return upper-cased version of s.
+string Uppercase(StringPiece s);
+
+// Capitalize first character of each word in "*s".  "delimiters" is a
+// set of characters that can be used as word boundaries.
+void TitlecaseString(string* s, StringPiece delimiters);
+
+// Join functionality
+template <typename T>
+string Join(const std::vector<T>& s, const char* sep);
+template <typename T>
+string Join(const gtl::ArraySlice<T>& s, const char* sep);
+
+struct AllowEmpty {
+  bool operator()(StringPiece sp) const { return true; }
+};
+struct SkipEmpty {
+  bool operator()(StringPiece sp) const { return !sp.empty(); }
+};
+struct SkipWhitespace {
+  bool operator()(StringPiece sp) const {
+    RemoveTrailingWhitespace(&sp);
+    return !sp.empty();
+  }
+};
+
+std::vector<string> Split(StringPiece text, char delim);
+template <typename Predicate>
+std::vector<string> Split(StringPiece text, char delim, Predicate p);
+
+// Split "text" at "delim" characters, and parse each component as
+// an integer.  If successful, adds the individual numbers in order
+// to "*result" and returns true.  Otherwise returns false.
+bool SplitAndParseAsInts(StringPiece text, char delim,
+                         std::vector<int32>* result);
+
+// ------------------------------------------------------------------
+// Implementation details below
+namespace internal {
+template <typename T>
+string JoinHelper(typename gtl::ArraySlice<T>::const_iterator begin,
+                  typename gtl::ArraySlice<T>::const_iterator end,
+                  const char* sep) {
+  string result;
+  bool first = true;
+  for (typename gtl::ArraySlice<T>::const_iterator it = begin; it != end;
+       ++it) {
+    tensorflow::strings::StrAppend(&result, (first ?
"" : sep), *it); + first = false; + } + return result; +} +} // namespace internal + +template +string Join(const std::vector& s, const char* sep) { + return Join(gtl::ArraySlice(s), sep); +} + +template +string Join(const gtl::ArraySlice& s, const char* sep) { + return internal::JoinHelper(s.begin(), s.end(), sep); +} + +inline std::vector Split(StringPiece text, char delim) { + return Split(text, delim, AllowEmpty()); +} + +template +std::vector Split(StringPiece text, char delim, Predicate p) { + std::vector result; + int token_start = 0; + if (!text.empty()) { + for (int i = 0; i < text.size() + 1; i++) { + if ((i == text.size()) || (text[i] == delim)) { + StringPiece token(text.data() + token_start, i - token_start); + if (p(token)) { + result.push_back(token.ToString()); + } + token_start = i + 1; + } + } + } + return result; +} + +} // namespace str_util +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_STR_UTIL_H_ diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc new file mode 100644 index 0000000000..f71cc6c609 --- /dev/null +++ b/tensorflow/core/lib/strings/str_util_test.cc @@ -0,0 +1,258 @@ +#include "tensorflow/core/lib/strings/str_util.h" + +#include + +namespace tensorflow { + +TEST(CEscape, Basic) { + EXPECT_EQ(str_util::CEscape("hello"), "hello"); + EXPECT_EQ(str_util::CEscape("hello\n"), "hello\\n"); + EXPECT_EQ(str_util::CEscape("hello\r"), "hello\\r"); + EXPECT_EQ(str_util::CEscape("\t\r\"'"), "\\t\\r\\\"\\'"); + EXPECT_EQ(str_util::CEscape("\320hi\200"), "\\320hi\\200"); +} + +string ExpectCUnescapeSuccess(StringPiece source) { + string dest; + string error; + EXPECT_TRUE(str_util::CUnescape(source, &dest, &error)) << error; + return dest; +} + +TEST(CUnescape, Basic) { + EXPECT_EQ("hello", ExpectCUnescapeSuccess("hello")); + EXPECT_EQ("hello\n", ExpectCUnescapeSuccess("hello\\n")); + EXPECT_EQ("hello\r", ExpectCUnescapeSuccess("hello\\r")); + EXPECT_EQ("\t\r\"'", ExpectCUnescapeSuccess("\\t\\r\\\"\\'")); + EXPECT_EQ("\320hi\200", ExpectCUnescapeSuccess("\\320hi\\200")); +} + +TEST(NumericParse32, Basic) { + int32 val = -1234; + EXPECT_TRUE(str_util::NumericParse32("0", &val) && val == 0); + EXPECT_TRUE(str_util::NumericParse32("123", &val) && val == 123); + EXPECT_TRUE(str_util::NumericParse32("-375", &val) && val == -375); + EXPECT_FALSE(str_util::NumericParse32("123hello", &val)); + EXPECT_FALSE(str_util::NumericParse32("hello123", &val)); +} + +TEST(StripTrailingWhitespace, Basic) { + string test; + test = "hello"; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, "hello"); + + test = "foo "; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, "foo"); + + test = " "; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, ""); + + test = ""; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, ""); + + test = " abc\t"; + str_util::StripTrailingWhitespace(&test); + EXPECT_EQ(test, " abc"); +} + +TEST(RemoveLeadingWhitespace, Basic) { + string text = " \t \n \r Quick\t"; + StringPiece data(text); + // check that all whitespace is removed + EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 11); + EXPECT_EQ(data, StringPiece("Quick\t")); + // check that non-whitespace is not removed + EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0); + EXPECT_EQ(data, StringPiece("Quick\t")); +} + +TEST(RemoveLeadingWhitespace, TerminationHandling) { + // check termination handling + string text = "\t"; + StringPiece data(text); + 
EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 1); + EXPECT_EQ(data, StringPiece("")); + + // check termination handling again + EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0); + EXPECT_EQ(data, StringPiece("")); +} + +TEST(RemoveTrailingWhitespace, Basic) { + string text = " \t \n \r Quick \t"; + StringPiece data(text); + // check that all whitespace is removed + EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 2); + EXPECT_EQ(data, StringPiece(" \t \n \r Quick")); + // check that non-whitespace is not removed + EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0); + EXPECT_EQ(data, StringPiece(" \t \n \r Quick")); +} + +TEST(RemoveTrailingWhitespace, TerminationHandling) { + // check termination handling + string text = "\t"; + StringPiece data(text); + EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 1); + EXPECT_EQ(data, StringPiece("")); + + // check termination handling again + EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0); + EXPECT_EQ(data, StringPiece("")); +} + +TEST(RemoveWhitespaceContext, Basic) { + string text = " \t \n \r Quick \t"; + StringPiece data(text); + // check that all whitespace is removed + EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 13); + EXPECT_EQ(data, StringPiece("Quick")); + // check that non-whitespace is not removed + EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0); + EXPECT_EQ(data, StringPiece("Quick")); + + // Test empty string + text = ""; + data = text; + EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0); + EXPECT_EQ(data, StringPiece("")); +} + +void TestConsumeLeadingDigits(StringPiece s, int64 expected, + StringPiece remaining) { + uint64 v; + StringPiece input(s); + if (str_util::ConsumeLeadingDigits(&input, &v)) { + EXPECT_EQ(v, static_cast(expected)); + EXPECT_EQ(input, remaining); + } else { + EXPECT_LT(expected, 0); + EXPECT_EQ(input, remaining); + } +} + +TEST(ConsumeLeadingDigits, Basic) { + TestConsumeLeadingDigits("123", 123, ""); + TestConsumeLeadingDigits("a123", -1, "a123"); + TestConsumeLeadingDigits("9_", 9, "_"); + TestConsumeLeadingDigits("11111111111xyz", 11111111111ll, "xyz"); + + // Overflow case + TestConsumeLeadingDigits("1111111111111111111111111111111xyz", -1, + "1111111111111111111111111111111xyz"); + + // 2^64 + TestConsumeLeadingDigits("18446744073709551616xyz", -1, + "18446744073709551616xyz"); + // 2^64-1 + TestConsumeLeadingDigits("18446744073709551615xyz", 18446744073709551615ull, + "xyz"); +} + +TEST(ConsumePrefix, Basic) { + string s("abcdef"); + StringPiece input(s); + EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdefg")); + EXPECT_EQ(input, "abcdef"); + + EXPECT_FALSE(str_util::ConsumePrefix(&input, "abce")); + EXPECT_EQ(input, "abcdef"); + + EXPECT_TRUE(str_util::ConsumePrefix(&input, "")); + EXPECT_EQ(input, "abcdef"); + + EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdeg")); + EXPECT_EQ(input, "abcdef"); + + EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcdef")); + EXPECT_EQ(input, ""); + + input = s; + EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcde")); + EXPECT_EQ(input, "f"); +} + +TEST(JoinStrings, Basic) { + std::vector s; + s = {"hi"}; + EXPECT_EQ(str_util::Join(s, " "), "hi"); + s = {"hi", "there", "strings"}; + EXPECT_EQ(str_util::Join(s, " "), "hi there strings"); + + std::vector sp; + sp = {"hi"}; + EXPECT_EQ(str_util::Join(sp, ",,"), "hi"); + sp = {"hi", "there", "strings"}; + EXPECT_EQ(str_util::Join(sp, "--"), "hi--there--strings"); +} + +TEST(Split, Basic) { + EXPECT_TRUE(str_util::Split("", ',').empty()); + 
EXPECT_EQ(str_util::Join(str_util::Split("a", ','), "|"), "a"); + EXPECT_EQ(str_util::Join(str_util::Split(",", ','), "|"), "|"); + EXPECT_EQ(str_util::Join(str_util::Split("a,b,c", ','), "|"), "a|b|c"); + EXPECT_EQ(str_util::Join(str_util::Split("a,,,b,,c,", ','), "|"), + "a|||b||c|"); + EXPECT_EQ(str_util::Join( + str_util::Split("a,,,b,,c,", ',', str_util::SkipEmpty()), "|"), + "a|b|c"); + EXPECT_EQ( + str_util::Join( + str_util::Split("a, ,b,,c,", ',', str_util::SkipWhitespace()), "|"), + "a|b|c"); +} + +TEST(SplitAndParseAsInts, Basic) { + std::vector nums; + EXPECT_TRUE(str_util::SplitAndParseAsInts("", ',', &nums)); + EXPECT_EQ(nums.size(), 0); + + EXPECT_TRUE(str_util::SplitAndParseAsInts("134", ',', &nums)); + EXPECT_EQ(nums.size(), 1); + EXPECT_EQ(nums[0], 134); + + EXPECT_TRUE(str_util::SplitAndParseAsInts("134,2,13,-5", ',', &nums)); + EXPECT_EQ(nums.size(), 4); + EXPECT_EQ(nums[0], 134); + EXPECT_EQ(nums[1], 2); + EXPECT_EQ(nums[2], 13); + EXPECT_EQ(nums[3], -5); + + EXPECT_FALSE(str_util::SplitAndParseAsInts("abc", ',', &nums)); + + EXPECT_FALSE(str_util::SplitAndParseAsInts("-13,abc", ',', &nums)); + + EXPECT_FALSE(str_util::SplitAndParseAsInts("13,abc,5", ',', &nums)); +} + +TEST(Lowercase, Basic) { + EXPECT_EQ("", str_util::Lowercase("")); + EXPECT_EQ("hello", str_util::Lowercase("hello")); + EXPECT_EQ("hello world", str_util::Lowercase("Hello World")); +} + +TEST(Uppercase, Basic) { + EXPECT_EQ("", str_util::Uppercase("")); + EXPECT_EQ("HELLO", str_util::Uppercase("hello")); + EXPECT_EQ("HELLO WORLD", str_util::Uppercase("Hello World")); +} + +TEST(TitlecaseString, Basic) { + string s = "sparse_lookup"; + str_util::TitlecaseString(&s, "_"); + ASSERT_EQ(s, "Sparse_Lookup"); + + s = "sparse_lookup"; + str_util::TitlecaseString(&s, " "); + ASSERT_EQ(s, "Sparse_lookup"); + + s = "dense"; + str_util::TitlecaseString(&s, " "); + ASSERT_EQ(s, "Dense"); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc new file mode 100644 index 0000000000..e564b9eb73 --- /dev/null +++ b/tensorflow/core/lib/strings/strcat.cc @@ -0,0 +1,194 @@ +#include "tensorflow/core/lib/strings/strcat.h" + +#include +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/gtl/stl_util.h" + +namespace tensorflow { +namespace strings { + +AlphaNum gEmptyAlphaNum(""); + +AlphaNum::AlphaNum(Hex hex) { + char *const end = &digits_[kFastToBufferSize]; + char *writer = end; + uint64 value = hex.value; + uint64 width = hex.spec; + // We accomplish minimum width by OR'ing in 0x10000 to the user's value, + // where 0x10000 is the smallest hex number that is as wide as the user + // asked for. + uint64 mask = ((static_cast(1) << (width - 1) * 4)) | value; + static const char hexdigits[] = "0123456789abcdef"; + do { + *--writer = hexdigits[value & 0xF]; + value >>= 4; + mask >>= 4; + } while (mask != 0); + piece_.set(writer, end - writer); +} + +// ---------------------------------------------------------------------- +// StrCat() +// This merges the given strings or integers, with no delimiter. This +// is designed to be the fastest possible way to construct a string out +// of a mix of raw C strings, StringPieces, strings, and integer values. +// ---------------------------------------------------------------------- + +// Append is merely a version of memcpy that returns the address of the byte +// after the area just overwritten. It comes in multiple flavors to minimize +// call overhead. 
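The Append helpers that follow, together with CatPieces/StrCat further down, all rely on one pattern: measure the total length first, size the destination exactly once, then memcpy each piece into place. A stripped-down sketch of that idea using plain `std::string` (the real code uses `gtl::STLStringResizeUninitialized` to skip the zero-fill; `ConcatOnce` is illustrative only, not part of this patch):

```c++
#include <cstddef>
#include <cstring>
#include <initializer_list>
#include <string>

// Illustrative only: concatenate pieces with exactly one allocation,
// mirroring the strategy StrCat/CatPieces use below.
std::string ConcatOnce(std::initializer_list<std::string> pieces) {
  std::size_t total = 0;
  for (const std::string& p : pieces) total += p.size();
  std::string result(total, '\0');           // one allocation, zero-filled
  char* out = &result[0];
  for (const std::string& p : pieces) {
    std::memcpy(out, p.data(), p.size());    // copy without reallocating
    out += p.size();
  }
  return result;
}
```

For example, `ConcatOnce({"Hello", ", ", "World"})` yields `"Hello, World"` with a single heap allocation, which is the property the specialized 2/3/4-argument StrCat overloads below are optimizing for.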
+static char *Append1(char *out, const AlphaNum &x) { + memcpy(out, x.data(), x.size()); + return out + x.size(); +} + +static char *Append2(char *out, const AlphaNum &x1, const AlphaNum &x2) { + memcpy(out, x1.data(), x1.size()); + out += x1.size(); + + memcpy(out, x2.data(), x2.size()); + return out + x2.size(); +} + +static char *Append4(char *out, const AlphaNum &x1, const AlphaNum &x2, + const AlphaNum &x3, const AlphaNum &x4) { + memcpy(out, x1.data(), x1.size()); + out += x1.size(); + + memcpy(out, x2.data(), x2.size()); + out += x2.size(); + + memcpy(out, x3.data(), x3.size()); + out += x3.size(); + + memcpy(out, x4.data(), x4.size()); + return out + x4.size(); +} + +string StrCat(const AlphaNum &a, const AlphaNum &b) { + string result; + gtl::STLStringResizeUninitialized(&result, a.size() + b.size()); + char *const begin = &*result.begin(); + char *out = Append2(begin, a, b); + DCHECK_EQ(out, begin + result.size()); + return result; +} + +string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c) { + string result; + gtl::STLStringResizeUninitialized(&result, a.size() + b.size() + c.size()); + char *const begin = &*result.begin(); + char *out = Append2(begin, a, b); + out = Append1(out, c); + DCHECK_EQ(out, begin + result.size()); + return result; +} + +string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d) { + string result; + gtl::STLStringResizeUninitialized(&result, + a.size() + b.size() + c.size() + d.size()); + char *const begin = &*result.begin(); + char *out = Append4(begin, a, b, c, d); + DCHECK_EQ(out, begin + result.size()); + return result; +} + +namespace internal { + +// Do not call directly - these are not part of the public API. +string CatPieces(std::initializer_list pieces) { + string result; + size_t total_size = 0; + for (const StringPiece piece : pieces) total_size += piece.size(); + gtl::STLStringResizeUninitialized(&result, total_size); + + char *const begin = &*result.begin(); + char *out = begin; + for (const StringPiece piece : pieces) { + const size_t this_size = piece.size(); + memcpy(out, piece.data(), this_size); + out += this_size; + } + DCHECK_EQ(out, begin + result.size()); + return result; +} + +// It's possible to call StrAppend with a StringPiece that is itself a fragment +// of the string we're appending to. However the results of this are random. +// Therefore, check for this in debug mode. Use unsigned math so we only have +// to do one comparison. 
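To make the hazard behind DCHECK_NO_OVERLAP concrete, here is a hedged sketch (assuming `tensorflow::string` is `std::string`, as in the open-source port; `AppendTailToSelf` is hypothetical). Appending a fragment of the destination to itself is exactly what the check catches in debug builds, because the resize inside StrAppend may reallocate the destination and leave the source pointer dangling; copying the fragment out first is the safe pattern.

```c++
#include <string>
#include "tensorflow/core/lib/strings/strcat.h"

// Illustrative only.
void AppendTailToSelf(std::string* s) {
  if (s->empty()) return;
  // Unsafe (flagged by DCHECK_NO_OVERLAP in debug builds):
  //   tensorflow::strings::StrAppend(s, s->c_str() + 1);
  // Safe: materialize the fragment into its own buffer before appending.
  const std::string tail = s->substr(1);
  tensorflow::strings::StrAppend(s, tail);
}
```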
+#define DCHECK_NO_OVERLAP(dest, src) \ + DCHECK_GE(uintptr_t((src).data() - (dest).data()), uintptr_t((dest).size())) + +void AppendPieces(string *result, std::initializer_list pieces) { + size_t old_size = result->size(); + size_t total_size = old_size; + for (const StringPiece piece : pieces) { + DCHECK_NO_OVERLAP(*result, piece); + total_size += piece.size(); + } + gtl::STLStringResizeUninitialized(result, total_size); + + char *const begin = &*result->begin(); + char *out = begin + old_size; + for (const StringPiece piece : pieces) { + const size_t this_size = piece.size(); + memcpy(out, piece.data(), this_size); + out += this_size; + } + DCHECK_EQ(out, begin + result->size()); +} + +} // namespace internal + +void StrAppend(string *result, const AlphaNum &a) { + DCHECK_NO_OVERLAP(*result, a); + result->append(a.data(), a.size()); +} + +void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b) { + DCHECK_NO_OVERLAP(*result, a); + DCHECK_NO_OVERLAP(*result, b); + string::size_type old_size = result->size(); + gtl::STLStringResizeUninitialized(result, old_size + a.size() + b.size()); + char *const begin = &*result->begin(); + char *out = Append2(begin + old_size, a, b); + DCHECK_EQ(out, begin + result->size()); +} + +void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c) { + DCHECK_NO_OVERLAP(*result, a); + DCHECK_NO_OVERLAP(*result, b); + DCHECK_NO_OVERLAP(*result, c); + string::size_type old_size = result->size(); + gtl::STLStringResizeUninitialized(result, + old_size + a.size() + b.size() + c.size()); + char *const begin = &*result->begin(); + char *out = Append2(begin + old_size, a, b); + out = Append1(out, c); + DCHECK_EQ(out, begin + result->size()); +} + +void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c, const AlphaNum &d) { + DCHECK_NO_OVERLAP(*result, a); + DCHECK_NO_OVERLAP(*result, b); + DCHECK_NO_OVERLAP(*result, c); + DCHECK_NO_OVERLAP(*result, d); + string::size_type old_size = result->size(); + gtl::STLStringResizeUninitialized( + result, old_size + a.size() + b.size() + c.size() + d.size()); + char *const begin = &*result->begin(); + char *out = Append4(begin + old_size, a, b, c, d); + DCHECK_EQ(out, begin + result->size()); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h new file mode 100644 index 0000000000..763ad8368a --- /dev/null +++ b/tensorflow/core/lib/strings/strcat.h @@ -0,0 +1,229 @@ +// #status: RECOMMENDED +// #category: operations on strings +// #summary: Merges strings or numbers with no delimiter. +// +#ifndef TENSORFLOW_LIB_STRINGS_STRCAT_H_ +#define TENSORFLOW_LIB_STRINGS_STRCAT_H_ + +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/port.h" + +// The AlphaNum type was designed to be used as the parameter type for StrCat(). +// Any routine accepting either a string or a number may accept it. +// The basic idea is that by accepting a "const AlphaNum &" as an argument +// to your function, your callers will automagically convert bools, integers, +// and floating point values to strings for you. +// +// NOTE: Use of AlphaNum outside of the //strings package is unsupported except +// for the specific case of function parameters of type "AlphaNum" or "const +// AlphaNum &". In particular, instantiating AlphaNum directly as a stack +// variable is not supported. 
+// +// Conversion from 8-bit values is not accepted because if it were, then an +// attempt to pass ':' instead of ":" might result in a 58 ending up in your +// result. +// +// Bools convert to "0" or "1". +// +// Floating point values are converted to a string which, if passed to strtod(), +// would produce the exact same original double (except in case of NaN; all NaNs +// are considered the same value). We try to keep the string short but it's not +// guaranteed to be as short as possible. +// +// You can convert to Hexadecimal output rather than Decimal output using Hex. +// To do this, pass strings::Hex(my_int) as a parameter to StrCat. You may +// specify a minimum field width using a separate parameter, so the equivalent +// of Printf("%04x", my_int) is StrCat(Hex(my_int, strings::ZERO_PAD_4)) +// +// This class has implicit constructors. +namespace tensorflow { +namespace strings { + +enum PadSpec { + NO_PAD = 1, + ZERO_PAD_2, + ZERO_PAD_3, + ZERO_PAD_4, + ZERO_PAD_5, + ZERO_PAD_6, + ZERO_PAD_7, + ZERO_PAD_8, + ZERO_PAD_9, + ZERO_PAD_10, + ZERO_PAD_11, + ZERO_PAD_12, + ZERO_PAD_13, + ZERO_PAD_14, + ZERO_PAD_15, + ZERO_PAD_16, +}; + +struct Hex { + uint64 value; + enum PadSpec spec; + template + explicit Hex(Int v, PadSpec s = NO_PAD) + : spec(s) { + // Prevent sign-extension by casting integers to + // their unsigned counterparts. + static_assert( + sizeof(v) == 1 || sizeof(v) == 2 || sizeof(v) == 4 || sizeof(v) == 8, + "Unknown integer type"); + value = sizeof(v) == 1 + ? static_cast(v) + : sizeof(v) == 2 ? static_cast(v) + : sizeof(v) == 4 ? static_cast(v) + : static_cast(v); + } +}; + +class AlphaNum { + public: + // No bool ctor -- bools convert to an integral type. + // A bool ctor would also convert incoming pointers (bletch). + + AlphaNum(int i32) // NOLINT(runtime/explicit) + : piece_(digits_, FastInt32ToBufferLeft(i32, digits_) - &digits_[0]) {} + AlphaNum(unsigned int u32) // NOLINT(runtime/explicit) + : piece_(digits_, FastUInt32ToBufferLeft(u32, digits_) - &digits_[0]) {} + AlphaNum(long x) // NOLINT(runtime/explicit) + : piece_(digits_, FastInt64ToBufferLeft(x, digits_) - &digits_[0]) {} + AlphaNum(unsigned long x) // NOLINT(runtime/explicit) + : piece_(digits_, FastUInt64ToBufferLeft(x, digits_) - &digits_[0]) {} + AlphaNum(long long int i64) // NOLINT(runtime/explicit) + : piece_(digits_, FastInt64ToBufferLeft(i64, digits_) - &digits_[0]) {} + AlphaNum(unsigned long long int u64) // NOLINT(runtime/explicit) + : piece_(digits_, FastUInt64ToBufferLeft(u64, digits_) - &digits_[0]) {} + + AlphaNum(float f) // NOLINT(runtime/explicit) + : piece_(digits_, strlen(FloatToBuffer(f, digits_))) {} + AlphaNum(double f) // NOLINT(runtime/explicit) + : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {} + + AlphaNum(Hex hex); // NOLINT(runtime/explicit) + + AlphaNum(const char *c_str) : piece_(c_str) {} // NOLINT(runtime/explicit) + AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit) + AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit) + : piece_(str) {} + + StringPiece::size_type size() const { return piece_.size(); } + const char *data() const { return piece_.data(); } + StringPiece Piece() const { return piece_; } + + private: + StringPiece piece_; + char digits_[kFastToBufferSize]; + + // Use ":" not ':' + AlphaNum(char c); // NOLINT(runtime/explicit) + + TF_DISALLOW_COPY_AND_ASSIGN(AlphaNum); +}; + +extern AlphaNum gEmptyAlphaNum; + +using strings::AlphaNum; +using strings::gEmptyAlphaNum; + +// 
---------------------------------------------------------------------- +// StrCat() +// This merges the given strings or numbers, with no delimiter. This +// is designed to be the fastest possible way to construct a string out +// of a mix of raw C strings, StringPieces, strings, bool values, +// and numeric values. +// +// Don't use this for user-visible strings. The localization process +// works poorly on strings built up out of fragments. +// +// For clarity and performance, don't use StrCat when appending to a +// string. In particular, avoid using any of these (anti-)patterns: +// str.append(StrCat(...)) +// str += StrCat(...) +// str = StrCat(str, ...) +// where the last is the worse, with the potential to change a loop +// from a linear time operation with O(1) dynamic allocations into a +// quadratic time operation with O(n) dynamic allocations. StrAppend +// is a better choice than any of the above, subject to the restriction +// of StrAppend(&str, a, b, c, ...) that none of the a, b, c, ... may +// be a reference into str. +// ---------------------------------------------------------------------- + +// For performance reasons, we have specializations for <= 4 args. +string StrCat(const AlphaNum &a) TF_MUST_USE_RESULT; +string StrCat(const AlphaNum &a, const AlphaNum &b) TF_MUST_USE_RESULT; +string StrCat(const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c) TF_MUST_USE_RESULT; +string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d) TF_MUST_USE_RESULT; + +// inline definitions must be duplicated due to TF_MUST_USE_RESULT +inline string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); } + +namespace internal { + +// Do not call directly - this is not part of the public API. +string CatPieces(std::initializer_list pieces); +void AppendPieces(string *dest, std::initializer_list pieces); + +} // namespace internal + +// Support 5 or more arguments +template +string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d, const AlphaNum &e, + const AV &... args) TF_MUST_USE_RESULT; + +template +inline string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d, const AlphaNum &e, const AV &... args) { + return internal::CatPieces({a.Piece(), b.Piece(), c.Piece(), d.Piece(), + e.Piece(), + static_cast(args).Piece()...}); +} + +// ---------------------------------------------------------------------- +// StrAppend() +// Same as above, but adds the output to the given string. +// WARNING: For speed, StrAppend does not try to check each of its input +// arguments to be sure that they are not a subset of the string being +// appended to. That is, while this will work: +// +// string s = "foo"; +// s += s; +// +// This will not (necessarily) work: +// +// string s = "foo"; +// StrAppend(&s, s); +// +// Note: while StrCat supports appending up to 26 arguments, StrAppend +// is currently limited to 9. That's rarely an issue except when +// automatically transforming StrCat to StrAppend, and can easily be +// worked around as consecutive calls to StrAppend are quite efficient. 
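A small sketch of the loop shape this guidance is about. The function and its inputs are hypothetical; only the StrCat/StrAppend calls come from this header.

```c++
#include <string>
#include <vector>
#include "tensorflow/core/lib/strings/strcat.h"

// Illustrative only: building one big string from many records.
std::string FormatIds(const std::vector<int>& ids) {
  std::string out;
  for (int id : ids) {
    // Avoid: out = tensorflow::strings::StrCat(out, "id=", id, "\n");
    // -- that re-copies the accumulated string on every iteration (quadratic).
    tensorflow::strings::StrAppend(&out, "id=", id, "\n");  // amortized linear
  }
  return out;
}
```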
+// ---------------------------------------------------------------------- + +void StrAppend(string *dest, const AlphaNum &a); +void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b); +void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c); +void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c, const AlphaNum &d); + +// Support 5 or more arguments +template +inline void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b, + const AlphaNum &c, const AlphaNum &d, const AlphaNum &e, + const AV &... args) { + internal::AppendPieces(dest, + {a.Piece(), b.Piece(), c.Piece(), d.Piece(), e.Piece(), + static_cast(args).Piece()...}); +} + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_STRCAT_H_ diff --git a/tensorflow/core/lib/strings/strcat_test.cc b/tensorflow/core/lib/strings/strcat_test.cc new file mode 100644 index 0000000000..9ff7d81af9 --- /dev/null +++ b/tensorflow/core/lib/strings/strcat_test.cc @@ -0,0 +1,324 @@ +#include "tensorflow/core/lib/strings/strcat.h" + +#include + +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/port.h" +#include + +namespace tensorflow { +namespace strings { + +// Test StrCat of ints and longs of various sizes and signdedness. +TEST(StrCat, Ints) { + const int16 s = -1; + const uint16 us = 2; + const int i = -3; + const unsigned int ui = 4; + const int32 l = -5; + const uint32 ul = 6; + const int64 ll = -7; + const uint64 ull = 8; + const ptrdiff_t ptrdiff = -9; + const size_t size = 10; + const ssize_t ssize = -11; + const intptr_t intptr = -12; + const uintptr_t uintptr = 13; + string answer; + answer = StrCat(s, us); + EXPECT_EQ(answer, "-12"); + answer = StrCat(i, ui); + EXPECT_EQ(answer, "-34"); + answer = StrCat(l, ul); + EXPECT_EQ(answer, "-56"); + answer = StrCat(ll, ull); + EXPECT_EQ(answer, "-78"); + answer = StrCat(ptrdiff, size); + EXPECT_EQ(answer, "-910"); + answer = StrCat(ssize, intptr); + EXPECT_EQ(answer, "-11-12"); + answer = StrCat(uintptr, 0); + EXPECT_EQ(answer, "130"); +} + +TEST(StrCat, Basics) { + string result; + + string strs[] = {"Hello", "Cruel", "World"}; + + StringPiece pieces[] = {"Hello", "Cruel", "World"}; + + const char *c_strs[] = {"Hello", "Cruel", "World"}; + + int32 i32s[] = {'H', 'C', 'W'}; + uint64 ui64s[] = {12345678910LL, 10987654321LL}; + + result = StrCat(false, true, 2, 3); + EXPECT_EQ(result, "0123"); + + result = StrCat(-1); + EXPECT_EQ(result, "-1"); + + result = StrCat(0.5); + EXPECT_EQ(result, "0.5"); + + result = StrCat(strs[1], pieces[2]); + EXPECT_EQ(result, "CruelWorld"); + + result = StrCat(strs[0], ", ", pieces[2]); + EXPECT_EQ(result, "Hello, World"); + + result = StrCat(strs[0], ", ", strs[1], " ", strs[2], "!"); + EXPECT_EQ(result, "Hello, Cruel World!"); + + result = StrCat(pieces[0], ", ", pieces[1], " ", pieces[2]); + EXPECT_EQ(result, "Hello, Cruel World"); + + result = StrCat(c_strs[0], ", ", c_strs[1], " ", c_strs[2]); + EXPECT_EQ(result, "Hello, Cruel World"); + + result = StrCat("ASCII ", i32s[0], ", ", i32s[1], " ", i32s[2], "!"); + EXPECT_EQ(result, "ASCII 72, 67 87!"); + + result = StrCat(ui64s[0], ", ", ui64s[1], "!"); + EXPECT_EQ(result, "12345678910, 10987654321!"); + + string one = "1"; // Actually, it's the size of this string that we want; a + // 64-bit build distinguishes between size_t and uint64, + // even though they're both unsigned 64-bit values. 
+ result = StrCat("And a ", one.size(), " and a ", &result[2] - &result[0], + " and a ", one, " 2 3 4", "!"); + EXPECT_EQ(result, "And a 1 and a 2 and a 1 2 3 4!"); + + // result = StrCat("Single chars won't compile", '!'); + // result = StrCat("Neither will NULLs", NULL); + result = StrCat("To output a char by ASCII/numeric value, use +: ", '!' + 0); + EXPECT_EQ(result, "To output a char by ASCII/numeric value, use +: 33"); + + float f = 100000.5; + result = StrCat("A hundred K and a half is ", f); + EXPECT_EQ(result, "A hundred K and a half is 100000.5"); + + double d = f; + d *= d; + result = StrCat("A hundred K and a half squared is ", d); + EXPECT_EQ(result, "A hundred K and a half squared is 10000100000.25"); + + result = StrCat(1, 2, 333, 4444, 55555, 666666, 7777777, 88888888, 999999999); + EXPECT_EQ(result, "12333444455555666666777777788888888999999999"); +} + +TEST(StrCat, MaxArgs) { + string result; + // Test 10 up to 26 arguments, the current maximum + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a"); + EXPECT_EQ(result, "123456789a"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b"); + EXPECT_EQ(result, "123456789ab"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c"); + EXPECT_EQ(result, "123456789abc"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d"); + EXPECT_EQ(result, "123456789abcd"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e"); + EXPECT_EQ(result, "123456789abcde"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f"); + EXPECT_EQ(result, "123456789abcdef"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g"); + EXPECT_EQ(result, "123456789abcdefg"); + result = + StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", "h"); + EXPECT_EQ(result, "123456789abcdefgh"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i"); + EXPECT_EQ(result, "123456789abcdefghi"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j"); + EXPECT_EQ(result, "123456789abcdefghij"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k"); + EXPECT_EQ(result, "123456789abcdefghijk"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l"); + EXPECT_EQ(result, "123456789abcdefghijkl"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m"); + EXPECT_EQ(result, "123456789abcdefghijklm"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m", "n"); + EXPECT_EQ(result, "123456789abcdefghijklmn"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m", "n", "o"); + EXPECT_EQ(result, "123456789abcdefghijklmno"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m", "n", "o", "p"); + EXPECT_EQ(result, "123456789abcdefghijklmnop"); + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", + "h", "i", "j", "k", "l", "m", "n", "o", "p", "q"); + EXPECT_EQ(result, "123456789abcdefghijklmnopq"); + // No limit thanks to C++11's variadic templates + result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "a", "b", "c", "d", "e", "f", + "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", + "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D", + "E", "F", "G", "H", 
"I", "J", "K", "L", "M", "N", "O", "P", + "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"); + EXPECT_EQ(result, + "12345678910abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"); +} + +TEST(StrAppend, Basics) { + string result = "existing text"; + + string strs[] = {"Hello", "Cruel", "World"}; + + StringPiece pieces[] = {"Hello", "Cruel", "World"}; + + const char *c_strs[] = {"Hello", "Cruel", "World"}; + + int32 i32s[] = {'H', 'C', 'W'}; + uint64 ui64s[] = {12345678910LL, 10987654321LL}; + + string::size_type old_size = result.size(); + StrAppend(&result, strs[0]); + EXPECT_EQ(result.substr(old_size), "Hello"); + + old_size = result.size(); + StrAppend(&result, strs[1], pieces[2]); + EXPECT_EQ(result.substr(old_size), "CruelWorld"); + + old_size = result.size(); + StrAppend(&result, strs[0], ", ", pieces[2]); + EXPECT_EQ(result.substr(old_size), "Hello, World"); + + old_size = result.size(); + StrAppend(&result, strs[0], ", ", strs[1], " ", strs[2], "!"); + EXPECT_EQ(result.substr(old_size), "Hello, Cruel World!"); + + old_size = result.size(); + StrAppend(&result, pieces[0], ", ", pieces[1], " ", pieces[2]); + EXPECT_EQ(result.substr(old_size), "Hello, Cruel World"); + + old_size = result.size(); + StrAppend(&result, c_strs[0], ", ", c_strs[1], " ", c_strs[2]); + EXPECT_EQ(result.substr(old_size), "Hello, Cruel World"); + + old_size = result.size(); + StrAppend(&result, "ASCII ", i32s[0], ", ", i32s[1], " ", i32s[2], "!"); + EXPECT_EQ(result.substr(old_size), "ASCII 72, 67 87!"); + + old_size = result.size(); + StrAppend(&result, ui64s[0], ", ", ui64s[1], "!"); + EXPECT_EQ(result.substr(old_size), "12345678910, 10987654321!"); + + string one = "1"; // Actually, it's the size of this string that we want; a + // 64-bit build distinguishes between size_t and uint64, + // even though they're both unsigned 64-bit values. + old_size = result.size(); + StrAppend(&result, "And a ", one.size(), " and a ", &result[2] - &result[0], + " and a ", one, " 2 3 4", "!"); + EXPECT_EQ(result.substr(old_size), "And a 1 and a 2 and a 1 2 3 4!"); + + // result = StrCat("Single chars won't compile", '!'); + // result = StrCat("Neither will NULLs", NULL); + old_size = result.size(); + StrAppend(&result, "To output a char by ASCII/numeric value, use +: ", + '!' 
+ 0); + EXPECT_EQ(result.substr(old_size), + "To output a char by ASCII/numeric value, use +: 33"); + + float f = 100000.5; + old_size = result.size(); + StrAppend(&result, "A hundred K and a half is ", f); + EXPECT_EQ(result.substr(old_size), "A hundred K and a half is 100000.5"); + + double d = f; + d *= d; + old_size = result.size(); + StrAppend(&result, "A hundred K and a half squared is ", d); + EXPECT_EQ(result.substr(old_size), + "A hundred K and a half squared is 10000100000.25"); + + // Test 9 arguments, the old maximum + old_size = result.size(); + StrAppend(&result, 1, 22, 333, 4444, 55555, 666666, 7777777, 88888888, 9); + EXPECT_EQ(result.substr(old_size), "1223334444555556666667777777888888889"); + + // No limit thanks to C++11's variadic templates + old_size = result.size(); + StrAppend(&result, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "a", "b", "c", "d", "e", + "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", + "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D", "E", + "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", + "S", "T", "U", "V", "W", "X", "Y", "Z", + "No limit thanks to C++11's variadic templates"); + EXPECT_EQ(result.substr(old_size), + "12345678910abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + "No limit thanks to C++11's variadic templates"); +} + +TEST(StrAppend, Death) { + string s = "self"; + EXPECT_DEBUG_DEATH(StrAppend(&s, s.c_str() + 1), "Check failed:"); + EXPECT_DEBUG_DEATH(StrAppend(&s, s), "Check failed:"); +} + +static void CheckHex64(uint64 v) { + using tensorflow::strings::Hex; + string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_16)); + string expected = Printf("%016llx", static_cast(v)); + EXPECT_EQ(expected, actual) << " decimal value " << v; + + actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8)); + expected = Printf("%08llx", static_cast(v)); + EXPECT_EQ(expected, actual) << " decimal value " << v; + + actual = StrCat(Hex(v)); + expected = Printf("%llx", static_cast(v)); + EXPECT_EQ(expected, actual) << " decimal value " << v; +} + +static void CheckHex32(uint32 v) { + using tensorflow::strings::Hex; + string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8)); + string expected = Printf("%08x", v); + EXPECT_EQ(expected, actual) << " decimal value " << v; + + actual = StrCat(Hex(v)); + expected = Printf("%x", v); + EXPECT_EQ(expected, actual) << " decimal value " << v; +} + +static void CheckHexSigned32(int32 v) { + using tensorflow::strings::Hex; + string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8)); + string expected = Printf("%08x", v); + EXPECT_EQ(expected, actual) << " decimal value " << v; + + actual = StrCat(Hex(v)); + expected = Printf("%x", v); + EXPECT_EQ(expected, actual) << " decimal value " << v; +} + +static void TestFastPrints() { + using tensorflow::strings::Hex; + + // Test min int to make sure that works + for (int i = 0; i < 10000; i++) { + CheckHex64(i); + CheckHex32(i); + CheckHexSigned32(i); + CheckHexSigned32(-i); + } + CheckHex64(0x123456789abcdef0ull); + CheckHex32(0x12345678); + + int8 minus_one_8bit = -1; + EXPECT_EQ("ff", StrCat(Hex(minus_one_8bit))); + + int16 minus_one_16bit = -1; + EXPECT_EQ("ffff", StrCat(Hex(minus_one_16bit))); +} + +TEST(Numbers, TestFunctionsMovedOverFromNumbersMain) { TestFastPrints(); } + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/stringprintf.cc b/tensorflow/core/lib/strings/stringprintf.cc new file mode 100644 index 0000000000..b354706cbd --- /dev/null +++ 
b/tensorflow/core/lib/strings/stringprintf.cc @@ -0,0 +1,85 @@ +#include "tensorflow/core/lib/strings/stringprintf.h" + +#include +#include // For va_list and related operations +#include // MSVC requires this for _vsnprintf +#include + +namespace tensorflow { +namespace strings { + +#ifdef COMPILER_MSVC +enum { IS_COMPILER_MSVC = 1 }; +#else +enum { IS_COMPILER_MSVC = 0 }; +#endif + +void Appendv(string* dst, const char* format, va_list ap) { + // First try with a small fixed size buffer + static const int kSpaceLength = 1024; + char space[kSpaceLength]; + + // It's possible for methods that use a va_list to invalidate + // the data in it upon use. The fix is to make a copy + // of the structure before using it and use that copy instead. + va_list backup_ap; + va_copy(backup_ap, ap); + int result = vsnprintf(space, kSpaceLength, format, backup_ap); + va_end(backup_ap); + + if (result < kSpaceLength) { + if (result >= 0) { + // Normal case -- everything fit. + dst->append(space, result); + return; + } + + if (IS_COMPILER_MSVC) { + // Error or MSVC running out of space. MSVC 8.0 and higher + // can be asked about space needed with the special idiom below: + va_copy(backup_ap, ap); + result = vsnprintf(NULL, 0, format, backup_ap); + va_end(backup_ap); + } + + if (result < 0) { + // Just an error. + return; + } + } + + // Increase the buffer size to the size requested by vsnprintf, + // plus one for the closing \0. + int length = result + 1; + char* buf = new char[length]; + + // Restore the va_list before we use it again + va_copy(backup_ap, ap); + result = vsnprintf(buf, length, format, backup_ap); + va_end(backup_ap); + + if (result >= 0 && result < length) { + // It fit + dst->append(buf, result); + } + delete[] buf; +} + +string Printf(const char* format, ...) { + va_list ap; + va_start(ap, format); + string result; + Appendv(&result, format, ap); + va_end(ap); + return result; +} + +void Appendf(string* dst, const char* format, ...) { + va_list ap; + va_start(ap, format); + Appendv(dst, format, ap); + va_end(ap); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/stringprintf.h b/tensorflow/core/lib/strings/stringprintf.h new file mode 100644 index 0000000000..23ca2583ca --- /dev/null +++ b/tensorflow/core/lib/strings/stringprintf.h @@ -0,0 +1,37 @@ +// Printf variants that place their output in a C++ string. +// +// Usage: +// string result = strings::Printf("%d %s\n", 10, "hello"); +// strings::SPrintf(&result, "%d %s\n", 10, "hello"); +// strings::Appendf(&result, "%d %s\n", 20, "there"); + +#ifndef TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_ +#define TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace strings { + +// Return a C++ string +extern string Printf(const char* format, ...) + // Tell the compiler to do printf format string checking. + TF_PRINTF_ATTRIBUTE(1, 2); + +// Append result to a supplied string +extern void Appendf(string* dst, const char* format, ...) + // Tell the compiler to do printf format string checking. + TF_PRINTF_ATTRIBUTE(2, 3); + +// Lower-level routine that takes a va_list and appends to a specified +// string. All other routines are just convenience wrappers around it. 
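Since Appendv is the primitive the other routines wrap, a caller can layer its own printf-style helper on top of it. A hedged sketch follows: `TaggedPrintf` is hypothetical and not part of this patch, and it assumes `tensorflow::string` is `std::string` so the pointer types line up.

```c++
#include <cstdarg>
#include <string>
#include "tensorflow/core/lib/strings/stringprintf.h"

// Hypothetical convenience wrapper: prefixes every formatted message with a
// tag, forwarding the caller's va_list straight to Appendv.
std::string TaggedPrintf(const char* tag, const char* format, ...) {
  std::string out(tag);
  out += ": ";
  va_list ap;
  va_start(ap, format);
  tensorflow::strings::Appendv(&out, format, ap);
  va_end(ap);
  return out;
}

// e.g. TaggedPrintf("io", "read %d bytes", 123) -> "io: read 123 bytes"
```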
+extern void Appendv(string* dst, const char* format, va_list ap); + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_ diff --git a/tensorflow/core/lib/strings/stringprintf_test.cc b/tensorflow/core/lib/strings/stringprintf_test.cc new file mode 100644 index 0000000000..737ed5c0e0 --- /dev/null +++ b/tensorflow/core/lib/strings/stringprintf_test.cc @@ -0,0 +1,113 @@ +#include "tensorflow/core/lib/strings/stringprintf.h" + +#include + +#include + +namespace tensorflow { +namespace strings { +namespace { + +TEST(PrintfTest, Empty) { + EXPECT_EQ("", Printf("%s", string().c_str())); + EXPECT_EQ("", Printf("%s", "")); +} + +TEST(PrintfTest, Misc) { +// MSVC does not support $ format specifier. +#if !defined(COMPILER_MSVC) + EXPECT_EQ("123hello w", Printf("%3$d%2$s %1$c", 'w', "hello", 123)); +#endif // !COMPILER_MSVC +} + +TEST(AppendfTest, Empty) { + string value("Hello"); + const char* empty = ""; + Appendf(&value, "%s", empty); + EXPECT_EQ("Hello", value); +} + +TEST(AppendfTest, EmptyString) { + string value("Hello"); + Appendf(&value, "%s", ""); + EXPECT_EQ("Hello", value); +} + +TEST(AppendfTest, String) { + string value("Hello"); + Appendf(&value, " %s", "World"); + EXPECT_EQ("Hello World", value); +} + +TEST(AppendfTest, Int) { + string value("Hello"); + Appendf(&value, " %d", 123); + EXPECT_EQ("Hello 123", value); +} + +TEST(PrintfTest, Multibyte) { + // If we are in multibyte mode and feed invalid multibyte sequence, + // Printf should return an empty string instead of running + // out of memory while trying to determine destination buffer size. + // see b/4194543. + + char* old_locale = setlocale(LC_CTYPE, NULL); + // Push locale with multibyte mode + setlocale(LC_CTYPE, "en_US.utf8"); + + const char kInvalidCodePoint[] = "\375\067s"; + string value = Printf("%.*s", 3, kInvalidCodePoint); + + // In some versions of glibc (e.g. eglibc-2.11.1, aka GRTEv2), snprintf + // returns error given an invalid codepoint. Other versions + // (e.g. eglibc-2.15, aka pre-GRTEv3) emit the codepoint verbatim. + // We test that the output is one of the above. + EXPECT_TRUE(value.empty() || value == kInvalidCodePoint); + + // Repeat with longer string, to make sure that the dynamically + // allocated path in StringAppendV is handled correctly. + int n = 2048; + char* buf = new char[n + 1]; + memset(buf, ' ', n - 3); + memcpy(buf + n - 3, kInvalidCodePoint, 4); + value = Printf("%.*s", n, buf); + // See GRTEv2 vs. GRTEv3 comment above. + EXPECT_TRUE(value.empty() || value == buf); + delete[] buf; + + setlocale(LC_CTYPE, old_locale); +} + +TEST(PrintfTest, NoMultibyte) { + // No multibyte handling, but the string contains funny chars. + char* old_locale = setlocale(LC_CTYPE, NULL); + setlocale(LC_CTYPE, "POSIX"); + string value = Printf("%.*s", 3, "\375\067s"); + setlocale(LC_CTYPE, old_locale); + EXPECT_EQ("\375\067s", value); +} + +TEST(PrintfTest, DontOverwriteErrno) { + // Check that errno isn't overwritten unless we're printing + // something significantly larger than what people are normally + // printing in their badly written PLOG() statements. + errno = ECHILD; + string value = Printf("Hello, %s!", "World"); + EXPECT_EQ(ECHILD, errno); +} + +TEST(PrintfTest, LargeBuf) { + // Check that the large buffer is handled correctly. 
+ int n = 2048; + char* buf = new char[n + 1]; + memset(buf, ' ', n); + buf[n] = 0; + string value = Printf("%s", buf); + EXPECT_EQ(buf, value); + delete[] buf; +} + +} // namespace + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc new file mode 100644 index 0000000000..8c0571b50e --- /dev/null +++ b/tensorflow/core/ops/array_ops.cc @@ -0,0 +1,892 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("Pack") + .Input("values: N * T") + .Output("output: T") + .Attr("N: int >= 1") + .Attr("T: type") + .Doc(R"doc( +Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. + +Packs the `N` tensors in `values` into a tensor with rank one higher than each +tensor in `values` and shape `[N] + values[0].shape`. The output satisfies +`output[i, ...] = values[i][...]`. + +This is the opposite of `unpack`. + +values: Must be of same shape and type. +output: The packed tensor. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Unpack") + .Input("value: T") + .Output("output: num * T") + .Attr("num: int >= 0") + .Attr("T: type") + .Doc(R"doc( +Unpacks the outer dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. + +Unpacks `num` tensors from `value` by chipping it along the first dimension. +The i'th tensor in `output` is the slice `value[i, ...]`. Each tensor in +`output` has shape `value.shape[1:]`. + +This is the opposite of `pack`. + +value: 1-D or higher, with first dimension `num`. +output: The list of tensors unpacked from `value`. +)doc"); + +// -------------------------------------------------------------------------- +// TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph +// in the N == 1 case to remove the node. +REGISTER_OP("Concat") + .Input("concat_dim: int32") + .Input("values: N * T") + .Output("output: T") + .Attr("N: int >= 2") + .Attr("T: type") + .Doc(R"doc( +Concatenates tensors along one dimension. + +concat_dim: 0-D. The dimension along which to concatenate. Must be in the + range [0, rank(values)). +values: The `N` Tensors to concatenate. Their ranks and types must match, + and their sizes must match in all dimensions except `concat_dim`. +output: A `Tensor` with the concatenation of values stacked along the + `concat_dim` dimension. This tensor's shape matches that of `values` except + in `concat_dim` where it has the sum of the sizes. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Split") + .Input("split_dim: int32") + .Input("value: T") + .Output("output: num_split * T") + .Attr("num_split: int >= 1") + .Attr("T: type") + .Doc(R"doc( +Splits a tensor into `num_split` tensors along one dimension. + +split_dim: 0-D. The dimension along which to split. Must be in the range + `[0, rank(value))`. +num_split: The number of ways to split. Must evenly divide + `value.shape[split_dim]`. +value: The tensor to split. +output: They are identically shaped tensors, whose shape matches that of `value` + except along `split_dim`, where their sizes are + `values.shape[split_dim] / num_split`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Const") + .Output("output: dtype") + .Attr("value: tensor") + .Attr("dtype: type") + .Doc(R"doc( +Returns a constant tensor. + +value: Attr `value` is the tensor to return. 
+)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("ZerosLike") + .Input("x: T") + .Output("y: T") + .Attr("T: type") + .Doc(R"doc( +Returns a tensor of zeros with the same shape and type as x. + +x: a tensor of type T. +y: a tensor of the same shape and type as x but filled with zeros. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Diag") + .Input("diagonal: T") + .Output("output: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Returns a diagonal tensor with a given diagonal values. + +Given a `diagonal`, this operation returns a tensor with the `diagonal` and +everything else padded with zeros. The diagonal is computed as follows: + +Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of +rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: + +`output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. + +For example: + +```prettyprint +# 'diagonal' is [1, 2, 3, 4] +tf.diag(diagonal) ==> [[1, 0, 0, 0] + [0, 2, 0, 0] + [0, 0, 3, 0] + [0, 0, 0, 4]] +``` + +diagonal: Rank k tensor where k is at most 3. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Reverse") + .Input("tensor: T") + .Input("dims: bool") + .Output("output: T") + .Attr("T: {uint8, int8, int32, bool, float, double}") + .Doc(R"Doc( +Reverses specific dimensions of a tensor. + +Given a `tensor`, and a `bool` tensor `dims` representing the dimensions +of `tensor`, this operation reverses each dimension i of `tensor` where +`dims[i]` is `True`. + +`tensor` can have up to 8 dimensions. The number of dimensions +of `tensor` must equal the number of elements in `dims`. In other words: + +`rank(tensor) = size(dims)` + +For example: + +```prettyprint +# tensor 't' is [[[[ 0, 1, 2, 3], +# [ 4, 5, 6, 7], +# [ 8, 9, 10, 11]], +# [[12, 13, 14, 15], +# [16, 17, 18, 19], +# [20, 21, 22, 23]]]] +# tensor 't' shape is [1, 2, 3, 4] + +# 'dims' is [False, False, False, True] +reverse(t, dims) ==> [[[[ 3, 2, 1, 0], + [ 7, 6, 5, 4], + [ 11, 10, 9, 8]], + [[15, 14, 13, 12], + [19, 18, 17, 16], + [23, 22, 21, 20]]]] + +# 'dims' is [False, True, False, False] +reverse(t, dims) ==> [[[[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23] + [[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]]]] + +# 'dims' is [False, False, True, False] +reverse(t, dims) ==> [[[[8, 9, 10, 11], + [4, 5, 6, 7], + [0, 1, 2, 3]] + [[20, 21, 22, 23], + [16, 17, 18, 19], + [12, 13, 14, 15]]]] +``` + +tensor: Up to 8-D. +dims: 1-D. The dimensions to reverse. +output: The same shape as `tensor`. +)Doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("EditDistance") + .Input("hypothesis_indices: int64") + .Input("hypothesis_values: T") + .Input("hypothesis_shape: int64") + .Input("truth_indices: int64") + .Input("truth_values: T") + .Input("truth_shape: int64") + .Attr("normalize: bool = True") + .Attr("T: type") + .Output("output: float") + .Doc(R"doc( +Computes the (possibly normalized) Levenshtein Edit Distance. + +The inputs are variable-length sequences provided by SparseTensors + (hypothesis_indices, hypothesis_values, hypothesis_shape) +and + (truth_indices, truth_values, truth_shape). + +The inputs are: + +hypothesis_indices: The indices of the hypothesis list SparseTensor. + This is an N x R int64 matrix. +hypothesis_values: The values of the hypothesis list SparseTensor. 
+ This is an N-length vector. +hypothesis_shape: The shape of the hypothesis list SparseTensor. + This is an R-length vector. +truth_indices: The indices of the truth list SparseTensor. + This is an M x R int64 matrix. +truth_values: The values of the truth list SparseTensor. + This is an M-length vector. +truth_shape: The shape of the truth list SparseTensor. + This is an R-length vector. +truth_shape: truth indices, vector. +normalize: boolean (if true, edit distances are normalized by length of truth). + +The output is: + +output: A dense float tensor with rank R - 1. + +For the example input: + + // hypothesis represents a 2x1 matrix with variable-length values: + // (0,0) = ["a"] + // (1,0) = ["b"] + hypothesis_indices = [[0, 0, 0], + [1, 0, 0]] + hypothesis_values = ["a", "b"] + hypothesis_shape = [2, 1, 1] + + // truth represents a 2x2 matrix with variable-length values: + // (0,0) = [] + // (0,1) = ["a"] + // (1,0) = ["b", "c"] + // (1,1) = ["a"] + truth_indices = [[0, 1, 0], + [1, 0, 0], + [1, 0, 1], + [1, 1, 0]] + truth_values = ["a", "b", "c", "a"] + truth_shape = [2, 2, 2] + normalize = true + +The output will be: + + // output is a 2x2 matrix with edit distances normalized by truth lengths. + output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis + [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Fill") + .Input("dims: int32") + .Input("value: T") + .Output("output: T") + .Attr("T: type") + .Doc(R"doc( +Creates a tensor filled with a scalar value. + +This operation creates a tensor of shape `dims` and fills it with `value`. + +For example: + +```prettyprint +# output tensor shape needs to be [2, 3] +# so 'dims' is [2, 3] +fill(dims, 9) ==> [[9, 9, 9] + [9, 9, 9]] +``` + +dims: 1-D. Represents the shape of the output tensor. +value: 0-D (scalar). Value to fill the returned tensor. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Gather") + .Input("params: Tparams") + .Input("indices: Tindices") + .Output("output: Tparams") + .Attr("Tparams: type") + .Attr("Tindices: {int32,int64}") + .Doc(R"doc( +Gather slices from `params` according to `indices`. + +`indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +Produces an output tensor with shape `indices.shape + params.shape[1:]` where: + + # Scalar indices + output[:, ..., :] = params[indices, :, ... :] + + # Vector indices + output[i, :, ..., :] = params[indices[i], :, ... :] + + # Higher rank indices + output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] + +If `indices` is a permutation and `len(indices) == params.shape[0]` then +this operation will permute `params` accordingly. + +
+)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Identity") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .Doc(R"Doc( +Return a tensor with the same shape and contents as the input tensor or value. +)Doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("RefIdentity") + .Input("input: Ref(T)") + .Output("output: Ref(T)") + .Attr("T: type") + .Doc(R"Doc( +Return the same ref tensor as the input ref tensor. +)Doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("StopGradient") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .Doc(R"Doc( +Stops gradient computation. + +When executed in a graph, this op outputs its input tensor as-is. + +When building ops to compute gradients, this op prevents the contribution of +its inputs to be taken into account. Normally, the gradient generator adds ops +to a graph to compute the derivatives of a specified 'loss' by recursively +finding out inputs that contributed to its computation. If you insert this op +in the graph it inputs are masked from the gradient generator. They are not +taken into account for computing gradients. + +This is useful any time you want to compute a value with TensorFlow but need +to pretend that the value was a constant. Some examples include: + +* The *EM* algorithm where the *M-step* should not involve backpropagation + through the output of the *E-step*. +* Contrastive divergence training of Boltzmann machines where, when + differentiating the energy function, the training must not backpropagate + through the graph that generated the samples from the model. +* Adversarial training, where no backprop should happen through the adversarial + example generation process. +)Doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("CheckNumerics") + .Input("tensor: T") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("message: string") + .Doc(R"doc( +Checks a tensor for NaN and Inf values. + +When run, reports an `InvalidArgument` error if `tensor` has any values +that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. + +message: Prefix of the error message. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Reshape") + .Input("tensor: T") + .Input("shape: int32") + .Output("output: T") + .Attr("T: type") + .Doc(R"Doc( +Reshapes a tensor. + +Given `tensor`, this operation returns a tensor that has the same values +as `tensor` with shape `shape`. + +If `shape` is the special value `[-1]`, then `tensor` is flattened and the +operation outputs a 1-D tensor with all elements of `tensor`. + +If `shape` is 1-D or higher, then the operation returns a tensor with shape +`shape` filled with the values of `tensor`. In this case, the number of elements +implied by `shape` must be the same as the number of elements in `tensor`. 
+ +For example: + +```prettyprint +# tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9] +# tensor 't' has shape [9] +reshape(t, [3, 3]) ==> [[1, 2, 3] + [4, 5, 6] + [7, 8, 9]] + +# tensor 't' is [[[1, 1], [2, 2]] +# [[3, 3], [4, 4]]] +# tensor 't' has shape [2, 2] +reshape(t, [2, 4]) ==> [[1, 1, 2, 2] + [3, 3, 4, 4]] + +# tensor 't' is [[[1, 1, 1], +# [2, 2, 2]], +# [[3, 3, 3], +# [4, 4, 4]], +# [[5, 5, 5], +# [6, 6, 6]]] +# tensor 't' has shape [3, 2, 3] +# pass '[-1]' to flatten 't' +reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] +``` + +shape: Defines the shape of the output tensor. +)Doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("InvertPermutation") + .Input("x: int32") + .Output("y: int32") + .Doc(R"doc( +Computes the inverse permutation of a tensor. + +This operation computes the inverse of an index permutation. It takes a 1-D +integer tensor `x`, which represents the indices of a zero-based array, and +swaps each value with its index position. In other words, for an ouput tensor +`y` and an input tensor `x`, this operation computes the following: + +`y[x[i]] = i for i in [0, 1, ..., len(x) - 1]` + +The values must include 0. There can be no duplicate values or negative values. + +For example: + +```prettyprint +# tensor `x` is [3, 4, 0, 2, 1] +invert_permutation(x) ==> [2, 4, 3, 0, 1] +``` + +x: 1-D. +y: 1-D. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Transpose") + .Input("x: T") + .Input("perm: int32") + .Output("y: T") + .Attr("T: type") + .Doc(R"doc( +Shuffle dimensions of x according to a permutation. + +The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: + `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Unique") + .Input("x: T") + .Output("y: T") + .Output("idx: int32") + .Attr("T: type") + .Doc(R"doc( +Finds unique elements in a 1-D tensor. + +This operation returns a tensor `y` containing all of the unique elements of `x` +sorted in the same order that they occur in `x`. This operation also returns a +tensor `idx` the same size as `x` that contains the index of each value of `x` +in the unique output `y`. In other words: + +`y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` + +For example: + +```prettyprint +# tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] +y, idx = unique(x) +y ==> [1, 2, 4, 7, 8] +idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +``` + +x: 1-D. +y: 1-D. +idx: 1-D. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Shape") + .Input("input: T") + .Output("output: int32") + .Attr("T: type") + .Doc(R"doc( +Returns the shape of a tensor. + +This operation returns a 1-D integer tensor representing the shape of `input`. + +For example: + +```prettyprint +# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +shape(t) ==> [2, 2, 3] +``` + +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("ReverseSequence") + .Input("input: T") + .Input("seq_lengths: int64") + .Output("output: T") + .Attr("seq_dim: int") + .Attr("T: type") + .Doc(R"doc( +Reverses variable length slices in dimension `seq_dim`. + +This op first slices `input` along the first dimension, and for each slice `i`, +reverses the first `seq_lengths[i]` elements along the dimension `seq_dim`. 
+ +The elements of `seq_lengths` must obey `seq_lengths[i] < input.dims[seq_dim]`, +and `seq_lengths` must be a vector of length `input.dims(0)`. + +The output slice `i` along dimension 0 is then given by input slice `i`, with +the first `seq_lengths[i]` slices along dimension `seq_dim` reversed. + +For example: + +```prettyprint +# Given this: +seq_dim = 1 +input.dims = (4, ...) +seq_lengths = [7, 2, 3, 5] + +# then slices of input are reversed on seq_dim, but only up to seq_lengths: +output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] +output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] +output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] +output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] + +# while entries past seq_lens are copied through: +output[0, 7:, :, ...] = input[0, 7:, :, ...] +output[1, 2:, :, ...] = input[1, 2:, :, ...] +output[2, 3:, :, ...] = input[2, 3:, :, ...] +output[3, 2:, :, ...] = input[3, 2:, :, ...] +``` + +input: The input to reverse. +seq_lengths: 1-D with length `input.dims(0)` and + `max(seq_lengths) < input.dims(seq_dim)` +seq_dim: The dimension which is partially reversed. +output: The partially reversed input. It has the same shape as `input`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Rank") + .Input("input: T") + .Output("output: int32") + .Attr("T: type") + .Doc(R"doc( +Returns the rank of a tensor. + +This operation returns an integer representing the rank of `input`. + +For example: + +```prettyprint +# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +# shape of tensor 't' is [2, 2, 3] +rank(t) ==> 3 +``` + +**Note**: The rank of a tensor is not the same as the rank of a matrix. The rank +of a tensor is the number of indices required to uniquely select each element +of the tensor. Rank is also known as "order", "degree", or "ndims." +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Size") + .Input("input: T") + .Output("output: int32") + .Attr("T: type") + .Doc(R"doc( +Returns the size of a tensor. + +This operation returns an integer representing the number of elements in +`input`. + +For example: + +```prettyprint +# 't' is [[[1, 1,, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]] +size(t) ==> 12 +``` + +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Slice") + .Input("input: T") + .Input("begin: Index") + .Input("size: Index") + .Output("output: T") + .Attr("T: type") + .Attr("Index: {int32,int64}") + .Doc(R"doc( +Return a slice from 'input'. + +The output tensor is a tensor with dimensions described by 'size' +whose values are extracted from 'input' starting at the offsets in +'begin'. + +*Requirements*: + 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) + +begin: begin[i] specifies the offset into the 'i'th dimension of + 'input' to slice from. +size: size[i] specifies the number of elements of the 'i'th dimension + of 'input' to slice. If size[i] is -1, all remaining elements in dimension + i are included in the slice (i.e. this is equivalent to setting + size[i] = input.dim_size(i) - begin[i]). +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Tile") + .Input("input: T") + .Input("multiples: int32") + .Output("output: T") + .Attr("T: type") + .Doc(R"doc( +Constructs a tensor by tiling a given tensor. + +This operation creates a new tensor by replicating `input` `multiples` times. 
+The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements, +and the values of `input` are replicated `multiples[i]` times along the 'i'th +dimension. For example, tiling `[a b c d]` by `[2]` produces +`[a b c d a b c d]`. + +input: 1-D or higher. +multiples: 1-D. Length must be the same as the number of dimensions in `input` +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("TileGrad") + .Input("input: T") + .Input("multiples: int32") + .Output("output: T") + .Attr("T: type") + .Doc(R"doc( +Returns the gradient of `Tile`. + +Since `Tile` takes an input and repeats the input `multiples` times +along each dimension, `TileGrad` takes in `multiples` and aggregates +each repeated tile of `input` into `output`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Where") + .Input("input: bool") + .Output("index: int64") + .Doc(R"doc( +Returns locations of true values in a boolean tensor. + +This operation returns the coordinates of true elements in `input`. The +coordinates are returned in a 2-D tensor where the first dimension (rows) +represents the number of true elements, and the second dimension (columns) +represents the coordinates of the true elements. Keep in mind, the shape of +the output tensor can vary depending on how many true values there are in +`input`. Indices are output in row-major order. + +For example: + +```prettyprint +# 'input' tensor is [[True, False] +# [True, False]] +# 'input' has two true values, so output has two coordinates. +# 'input' has rank of 2, so coordinates have two indices. +where(input) ==> [[0, 0], + [1, 0]] + +# `input` tensor is [[[True, False] +# [True, False]] +# [[False, True] +# [False, True]] +# [[False, False] +# [False, True]]] +# 'input' has 5 true values, so output has 5 coordinates. +# 'input' has rank of 3, so coordinates have three indices. +where(input) ==> [[0, 0, 0], + [0, 1, 0], + [1, 0, 1], + [1, 1, 1], + [2, 1, 1]] +``` + +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("BroadcastGradientArgs") + .Input("s0: int32") + .Input("s1: int32") + .Output("r0: int32") + .Output("r1: int32") + .Doc(R"doc( +Return the reduction indices for computing gradients of s0 op s1 with broadcast. + +This is typically used by gradient computations for a broadcasting operation. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("Pad") + .Input("input: T") + .Input("paddings: int32") + .Output("output: T") + .Attr("T: type") + .Doc(R"doc( +Pads a tensor with zeros. + +This operation pads a `input` with zeros according to the `paddings` you +specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +how many zeros to add before the contents of `input` in that dimension, and +`paddings[D, 1]` indicates how many zeros to add after the contents of `input` +in that dimension. 
+ +The padded size of each dimension D of the output is: + +`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` + +For example: + +```prettyprint +# 't' is [[1, 1], [2, 2]] +# 'paddings' is [[1, 1]], [2, 2]] +# rank of 't' is 2 +pad(t, paddings) ==> [[0, 0, 0, 0, 0] + [0, 0, 0, 0, 0] + [0, 1, 1, 0, 0] + [[0, 2, 2, 0, 0] + [0, 0, 0, 0, 0]] +``` + +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Placeholder") + .Output("output: dtype") + .Attr("dtype: type") + .Attr("shape: shape") + .Doc(R"doc( +A placeholder op for a value that will be fed into the computation. + +N.B. This operation will fail with an error if it is executed. It is +intended as a way to represent a value that will always be fed, and to +provide attrs that enable the fed value to be checked at runtime. + +output: A placeholder tensor that must be replaced using the feed mechanism. +dtype: The type of elements in the tensor. +shape: (Optional) The shape of the tensor. If the shape has 0 dimensions, the + shape is unconstrained. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("ExpandDims") + .Input("input: T") + .Input("dim: int32") + .Output("output: T") + .Attr("T: type") + .Doc(R"doc( +Inserts a dimension of 1 into a tensor's shape. + +Given a tensor `input`, this operation inserts a dimension of 1 at the +dimension index `dim` of `input`'s shape. The dimension index `dim` starts at +zero; if you specify a negative number for `dim` it is counted backward from +the end. + +This operation is useful if you want to add a batch dimension to a single +element. For example, if you have a single image of shape `[height, width, +channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, +which will make the shape `[1, height, width, channels]`. + +Other examples: + +```prettyprint +# 't' is a tensor of shape [2] +shape(expand_dims(t, 0)) ==> [1, 2] +shape(expand_dims(t, 1)) ==> [2, 1] +shape(expand_dims(t, -1)) ==> [2, 1] + +# 't2' is a tensor of shape [2, 3, 5] +shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] +shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] +shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] +``` + +This operation requires that: + +`-1-input.dims() <= dim <= input.dims()` + +This operation is related to `squeeze()`, which removes dimensions of +size 1. + +dim: 0-D (scalar). Specifies the dimension index at which to + expand the shape of `input`. +output: Contains the same data as `input`, but its shape has an additional + dimension of size 1 added. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Squeeze") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .Attr("squeeze_dims: list(int) >= 0 = []") + .Doc(R"doc( +Removes dimensions of size 1 from the shape of a tensor. + +Given a tensor `input`, this operation returns a tensor of the same type with +all dimensions of size 1 removed. If you don't want to remove all size 1 +dimensions, you can remove specific size 1 dimensions by specifying +`squeeze_dims`. + +For example: + +```prettyprint +# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +shape(squeeze(t)) ==> [2, 3] +``` + +Or, to remove specific size 1 dimensions: + +```prettyprint +# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] +``` + +input: The `input` to squeeze. +squeeze_dims: If specified, only squeezes the dimensions listed. The dimension + index starts at 0. 
It is an error to squeeze a dimension that is not 1. +output: Contains the same data as `input`, but has one or more dimensions of + size 1 removed. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("ListDiff") + .Input("x: T") + .Input("y: T") + .Output("out: T") + .Output("idx: int32") + .Attr("T: type") + .Doc(R"doc( +Computes the difference between two lists of numbers. + +Given a list `x` and a list `y`, this operation returns a list `out` that +represents all numbers that are in `x` but not in `y`. The returned list `out` +is sorted in the same order that the numbers appear in `x` (duplicates are +preserved). This operation also returns a list `idx` that represents the +position of each `out` element in `x`. In other words: + +`out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` + +For example, given this input: + +```prettyprint +x = [1, 2, 3, 4, 5, 6] +y = [1, 3, 5] +``` + +This operation would return: + +```prettyprint +out ==> [2, 4, 6] +idx ==> [1, 3, 5] +``` + +x: 1-D. Values to keep. +y: 1-D. Values to remove. +out: 1-D. Values present in `x` but not in `y`. +idx: 1-D. Positions of `x` values preserved in `out`. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/attention_ops.cc b/tensorflow/core/ops/attention_ops.cc new file mode 100644 index 0000000000..6fa9a6e821 --- /dev/null +++ b/tensorflow/core/ops/attention_ops.cc @@ -0,0 +1,54 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// Tout = extract_glimpse(Tin, size, offsets) extract the glimpse of size size +// centered at location offsets from the input tensor Tin +// +// REQUIRES: Tin.dims() == 4 +// +REGISTER_OP("ExtractGlimpse") + .Input("input: float") + .Input("size: int32") + .Input("offsets: float") + .Output("glimpse: float") + .Attr("centered: bool = true") + .Attr("normalized: bool = true") + .Attr("uniform_noise: bool = true") + .Doc(R"doc( +Extracts a glimpse from the input tensor. + +Returns a set of windows called glimpses extracted at location `offsets` +from the input tensor. If the windows only partially overlaps the inputs, the +non overlapping areas will be filled with random noise. + +The result is a 4-D tensor of shape `[batch_size, glimpse_height, +glimpse_width, channels]`. The channels and batch dimensions are the same as that +of the input tensor. The height and width of the output windows are +specified in the `size` parameter. + +The argument `normalized` and `centered` controls how the windows are built: +* If the coordinates are normalized but not centered, 0.0 and 1.0 + correspond to the minimum and maximum of each height and width dimension. +* If the coordinates are both normalized and centered, they range from -1.0 to + 1.0. The coordinates (-1.0, -1.0) correspond to the upper left corner, the + lower right corner is located at (1.0, 1.0) and the center is at (0, 0). +* If the coordinates are not normalized they are interpreted as numbers of pixels. + +input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. +size: A 1-D tensor of 2 elements containing the size of the glimpses to extract. + The glimpse height must be specified first, following by the glimpse width. +offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing the x, y + locations of the center of each window. +glimpse: A tensor representing the glimpses `[batch_size, glimpse_height, + glimpse_width, channels]`. 
+centered: indicates if the offset coordinates are centered relative to + the image, in which case the (0, 0) offset is relative to the center of the + input images. If false, the (0,0) offset corresponds to the upper left corner + of the input images. +normalized: indicates if the offset coordinates are normalized. +uniform_noise: indicates if the noise should be generated using a + uniform distribution or a gaussian distribution. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc new file mode 100644 index 0000000000..a98b0295ee --- /dev/null +++ b/tensorflow/core/ops/candidate_sampling_ops.cc @@ -0,0 +1,351 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("UniformCandidateSampler") + .Input("true_classes: int64") + .Output("sampled_candidates: int64") + .Output("true_expected_count: float") + .Output("sampled_expected_count: float") + .Attr("num_true: int >= 1") + .Attr("num_sampled: int >= 1") + .Attr("unique: bool") + .Attr("range_max: int >= 1") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Doc(R"doc( +Generates labels for candidate sampling with a uniform distribution. + +See explanations of candidate sampling and the data formats at +go/candidate-sampling. + +For each batch, this op picks a single set of sampled candidate labels. + +The advantages of sampling candidates per-batch are simplicity and the +possibility of efficient dense matrix multiplication. The disadvantage is that +the sampled candidates must be chosen independently of the context and of the +true labels. + +true_classes: A batch_size * num_true matrix, in which each row contains the + IDs of the num_true target_classes in the corresponding original label. +sampled_candidates: A vector of length num_sampled, in which each element is + the ID of a sampled candidate. +true_expected_count: A batch_size * num_true matrix, representing + the number of times each candidate is expected to occur in a batch + of sampled candidates. If unique=true, then this is a probability. +sampled_expected_count: A vector of length num_sampled, for each sampled + candidate represting the number of times the candidate is expected + to occur in a batch of sampled candidates. If unique=true, then this is a + probability. +num_true: Number of true labels per context. +num_sampled: Number of candidates to randomly sample per batch. +unique: If unique is true, we sample with rejection, so that all sampled + candidates in a batch are unique. This requires some approximation to + estimate the post-rejection sampling probabilities. +range_max: The sampler will sample integers from the interval [0, range_max). +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. +)doc"); + +REGISTER_OP("LogUniformCandidateSampler") + .Input("true_classes: int64") + .Output("sampled_candidates: int64") + .Output("true_expected_count: float") + .Output("sampled_expected_count: float") + .Attr("num_true: int >= 1") + .Attr("num_sampled: int >= 1") + .Attr("unique: bool") + .Attr("range_max: int >= 1") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Doc(R"doc( +Generates labels for candidate sampling with a log-uniform distribution. + +See explanations of candidate sampling and the data formats at +go/candidate-sampling. 
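+
+Class IDs are assumed to be ordered by decreasing frequency. Under the
+standard log-uniform (Zipfian) construction (the exact form is determined by
+the underlying sampler implementation, not by this registration), the
+probability of drawing a class ID `c` in `[0, range_max)` is approximately:
+
+```prettyprint
+P(c) = (log(c + 2) - log(c + 1)) / log(range_max + 1)
+```
+
+so smaller class IDs are sampled more often.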
+ +For each batch, this op picks a single set of sampled candidate labels. + +The advantages of sampling candidates per-batch are simplicity and the +possibility of efficient dense matrix multiplication. The disadvantage is that +the sampled candidates must be chosen independently of the context and of the +true labels. + + +true_classes: A batch_size * num_true matrix, in which each row contains the + IDs of the num_true target_classes in the corresponding original label. +sampled_candidates: A vector of length num_sampled, in which each element is + the ID of a sampled candidate. +true_expected_count: A batch_size * num_true matrix, representing + the number of times each candidate is expected to occur in a batch + of sampled candidates. If unique=true, then this is a probability. +sampled_expected_count: A vector of length num_sampled, for each sampled + candidate represting the number of times the candidate is expected + to occur in a batch of sampled candidates. If unique=true, then this is a + probability. +num_true: Number of true labels per context. +num_sampled: Number of candidates to randomly sample per batch. +unique: If unique is true, we sample with rejection, so that all sampled + candidates in a batch are unique. This requires some approximation to + estimate the post-rejection sampling probabilities. +range_max: The sampler will sample integers from the interval [0, range_max). +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. +)doc"); + +REGISTER_OP("LearnedUnigramCandidateSampler") + .Input("true_classes: int64") + .Output("sampled_candidates: int64") + .Output("true_expected_count: float") + .Output("sampled_expected_count: float") + .Attr("num_true: int >= 1") + .Attr("num_sampled: int >= 1") + .Attr("unique: bool") + .Attr("range_max: int >= 1") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Doc(R"doc( +Generates labels for candidate sampling with a learned unigram distribution. + +See explanations of candidate sampling and the data formats at +go/candidate-sampling. + +For each batch, this op picks a single set of sampled candidate labels. + +The advantages of sampling candidates per-batch are simplicity and the +possibility of efficient dense matrix multiplication. The disadvantage is that +the sampled candidates must be chosen independently of the context and of the +true labels. + +true_classes: A batch_size * num_true matrix, in which each row contains the + IDs of the num_true target_classes in the corresponding original label. +sampled_candidates: A vector of length num_sampled, in which each element is + the ID of a sampled candidate. +true_expected_count: A batch_size * num_true matrix, representing + the number of times each candidate is expected to occur in a batch + of sampled candidates. If unique=true, then this is a probability. +sampled_expected_count: A vector of length num_sampled, for each sampled + candidate represting the number of times the candidate is expected + to occur in a batch of sampled candidates. If unique=true, then this is a + probability. +num_true: Number of true labels per context. +num_sampled: Number of candidates to randomly sample per batch. +unique: If unique is true, we sample with rejection, so that all sampled + candidates in a batch are unique. This requires some approximation to + estimate the post-rejection sampling probabilities. 
+range_max: The sampler will sample integers from the interval [0, range_max). +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. +)doc"); + +REGISTER_OP("ThreadUnsafeUnigramCandidateSampler") + .Input("true_classes: int64") + .Output("sampled_candidates: int64") + .Output("true_expected_count: float") + .Output("sampled_expected_count: float") + .Attr("num_true: int >= 1") + .Attr("num_sampled: int >= 1") + .Attr("unique: bool") + .Attr("range_max: int >= 1") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Doc(R"doc( +Generates labels for candidate sampling with a learned unigram distribution. + +See explanations of candidate sampling and the data formats at +go/candidate-sampling. + +For each batch, this op picks a single set of sampled candidate labels. + +The advantages of sampling candidates per-batch are simplicity and the +possibility of efficient dense matrix multiplication. The disadvantage is that +the sampled candidates must be chosen independently of the context and of the +true labels. + +true_classes: A batch_size * num_true matrix, in which each row contains the + IDs of the num_true target_classes in the corresponding original label. +sampled_candidates: A vector of length num_sampled, in which each element is + the ID of a sampled candidate. +true_expected_count: A batch_size * num_true matrix, representing + the number of times each candidate is expected to occur in a batch + of sampled candidates. If unique=true, then this is a probability. +sampled_expected_count: A vector of length num_sampled, for each sampled + candidate represting the number of times the candidate is expected + to occur in a batch of sampled candidates. If unique=true, then this is a + probability. +num_true: Number of true labels per context. +num_sampled: Number of candidates to randomly sample per batch. +unique: If unique is true, we sample with rejection, so that all sampled + candidates in a batch are unique. This requires some approximation to + estimate the post-rejection sampling probabilities. +range_max: The sampler will sample integers from the interval [0, range_max). +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. +)doc"); + +REGISTER_OP("FixedUnigramCandidateSampler") + .Input("true_classes: int64") + .Output("sampled_candidates: int64") + .Output("true_expected_count: float") + .Output("sampled_expected_count: float") + .Attr("num_true: int >= 1") + .Attr("num_sampled: int >= 1") + .Attr("unique: bool") + .Attr("range_max: int >= 1") + .Attr("vocab_file: string = ''") + .Attr("distortion: float = 1.0") + .Attr("num_reserved_ids: int = 0") + .Attr("num_shards: int >= 1 = 1") + .Attr("shard: int >= 0 = 0") + .Attr("unigrams: list(float) = []") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Doc(R"doc( +Generates labels for candidate sampling with a learned unigram distribution. + +A unigram sampler could use a fixed unigram distribution read from a +file or passed in as an in-memory array instead of building up the distribution +from data on the fly. There is also an option to skew the distribution by +applying a distortion power to the weights. + +The vocabulary file should be in CSV-like format, with the last field +being the weight associated with the word. 
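+
+For example, a small vocabulary file might look like the following (the words
+and counts here are purely illustrative):
+
+```prettyprint
+the,9035
+of,5412
+and,5134
+```
+
+Line order determines the word IDs: with num_reserved_ids = 0 the first line
+corresponds to ID 0, the second to ID 1, and so on.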
+ +For each batch, this op picks a single set of sampled candidate labels. + +The advantages of sampling candidates per-batch are simplicity and the +possibility of efficient dense matrix multiplication. The disadvantage is that +the sampled candidates must be chosen independently of the context and of the +true labels. + +true_classes: A batch_size * num_true matrix, in which each row contains the + IDs of the num_true target_classes in the corresponding original label. +sampled_candidates: A vector of length num_sampled, in which each element is + the ID of a sampled candidate. +true_expected_count: A batch_size * num_true matrix, representing + the number of times each candidate is expected to occur in a batch + of sampled candidates. If unique=true, then this is a probability. +sampled_expected_count: A vector of length num_sampled, for each sampled + candidate represting the number of times the candidate is expected + to occur in a batch of sampled candidates. If unique=true, then this is a + probability. +num_true: Number of true labels per context. +num_sampled: Number of candidates to randomly sample per batch. +unique: If unique is true, we sample with rejection, so that all sampled + candidates in a batch are unique. This requires some approximation to + estimate the post-rejection sampling probabilities. +range_max: The sampler will sample integers from the interval [0, range_max). +vocab_file: Each valid line in this file (which should have a CSV-like format) + corresponds to a valid word ID. IDs are in sequential order, starting from + num_reserved_ids. The last entry in each line is expected to be a value + corresponding to the count or relative probability. Exactly one of vocab_file + and unigrams needs to be passed to this op. +distortion: The distortion is used to skew the unigram probability distribution. + Each weight is first raised to the distortion's power before adding to the + internal unigram distribution. As a result, distortion = 1.0 gives regular + unigram sampling (as defined by the vocab file), and distortion = 0.0 gives + a uniform distribution. +num_reserved_ids: Optionally some reserved IDs can be added in the range [0, + ..., num_reserved_ids) by the users. One use case is that a special unknown + word token is used as ID 0. These IDs will have a sampling probability of 0. +num_shards: A sampler can be used to sample from a subset of the original range + in order to speed up the whole computation through parallelism. This parameter + (together with 'shard') indicates the number of partitions that are being + used in the overall computation. +shard: A sampler can be used to sample from a subset of the original range + in order to speed up the whole computation through parallelism. This parameter + (together with 'num_shards') indicates the particular partition number of a + sampler op, when partitioning is being used. +unigrams: A list of unigram counts or probabilities, one per ID in sequential + order. Exactly one of vocab_file and unigrams should be passed to this op. +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. 
+)doc"); + +REGISTER_OP("AllCandidateSampler") + .Input("true_classes: int64") + .Output("sampled_candidates: int64") + .Output("true_expected_count: float") + .Output("sampled_expected_count: float") + .Attr("num_true: int >= 1") + .Attr("num_sampled: int >= 1") + .Attr("unique: bool") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Doc(R"doc( +Generates labels for candidate sampling with a learned unigram distribution. + +See explanations of candidate sampling and the data formats at +go/candidate-sampling. + +For each batch, this op picks a single set of sampled candidate labels. + +The advantages of sampling candidates per-batch are simplicity and the +possibility of efficient dense matrix multiplication. The disadvantage is that +the sampled candidates must be chosen independently of the context and of the +true labels. + +true_classes: A batch_size * num_true matrix, in which each row contains the + IDs of the num_true target_classes in the corresponding original label. +sampled_candidates: A vector of length num_sampled, in which each element is + the ID of a sampled candidate. +true_expected_count: A batch_size * num_true matrix, representing + the number of times each candidate is expected to occur in a batch + of sampled candidates. If unique=true, then this is a probability. +sampled_expected_count: A vector of length num_sampled, for each sampled + candidate represting the number of times the candidate is expected + to occur in a batch of sampled candidates. If unique=true, then this is a + probability. +num_true: Number of true labels per context. +num_sampled: Number of candidates to produce per batch. +unique: If unique is true, we sample with rejection, so that all sampled + candidates in a batch are unique. This requires some approximation to + estimate the post-rejection sampling probabilities. +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. +)doc"); + +REGISTER_OP("ComputeAccidentalHits") + .Input("true_classes: int64") + .Input("sampled_candidates: int64") + .Output("indices: int32") + .Output("ids: int64") + .Output("weights: float") + .Attr("num_true: int") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Doc(R"doc( +Computes the ids of the positions in sampled_candidates that match true_labels. + +When doing log-odds NCE, the result of this op should be passed through a +SparseToDense op, then added to the logits of the sampled candidates. This has +the effect of 'removing' the sampled labels that match the true labels by +making the classifier sure that they are sampled labels. + +true_classes: The true_classes output of UnpackSparseLabels. +sampled_candidates: The sampled_candidates output of CandidateSampler. +indices: A vector of indices corresponding to rows of true_candidates. +ids: A vector of IDs of positions in sampled_candidates that match a true_label + for the row with the corresponding index in indices. +weights: A vector of the same length as indices and ids, in which each element + is -FLOAT_MAX. +num_true: Number of true labels per context. +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. 
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc new file mode 100644 index 0000000000..517b2d2742 --- /dev/null +++ b/tensorflow/core/ops/control_flow_ops.cc @@ -0,0 +1,179 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- +REGISTER_OP("Switch") + .Input("data: T") + .Input("pred: bool") + .Output("output_false: T") + .Output("output_true: T") + .Attr("T: type") + .Doc(R"doc( +Forwards `data` to the output port determined by `pred`. + +If `pred` is true, the `data` input is forwared to `output_true`. Otherwise, +the data goes to `output_false`. + +See also `RefSwitch` and `Merge`. + +data: The tensor to be forwarded to the appropriate output. +pred: A scalar that specifies which output port will receive data. +output_false: If `pred` is false, data will be forwarded to this output. +output_true: If `pred` is true, data will be forwarded to this output. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("RefSwitch") + .Input("data: Ref(T)") + .Input("pred: bool") + .Output("output_false: Ref(T)") + .Output("output_true: Ref(T)") + .Attr("T: type") + .Doc(R"doc( +Forwards the ref tensor `data` to the output port determined by `pred`. + +If `pred` is true, the `data` input is forwared to `output_true`. Otherwise, +the data goes to `output_false`. + +See also `Switch` and `Merge`. + +data: The ref tensor to be forwarded to the appropriate output. +pred: A scalar that specifies which output port will receive data. +output_false: If `pred` is false, data will be forwarded to this output. +output_true: If `pred` is true, data will be forwarded to this output. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("RefSelect") + .Input("index: int32") + .Input("inputs: Ref(N * T)") + .Output("output: Ref(T)") + .Attr("T: type") + .Attr("N: int >= 1") + .Doc(R"doc( +Forwards the `index`th element of `inputs` to `output`. + +index: A scalar that determines the input that gets selected. +inputs: A list of ref tensors, one of which will be forwarded to `output`. +output: The forwarded tensor. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Merge") + .Input("inputs: N * T") + .Output("output: T") + .Output("value_index: int32") + .Attr("T: type") + .Attr("N: int >= 1") + .Doc(R"doc( +Forwards the value of an available tensor from `inputs` to `output`. + +`Merge` waits for at least one of the tensors in `inputs` to become available. +It is usually combined with `Switch` to implement branching. + +`Merge` forwards the first tensor for become available to `output`, and sets +`value_index` to its index in `inputs`. + +It is an error if more than one tensor in `inputs` is available. + +inputs: The input tensors, exactly one of which will become available. +output: Will be set to the available input tensor. +value_index: The index of the chosen input tensor in `inputs`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Enter") + .Input("data: T") + .Output("output: T") + .Attr("T: type") + .Attr("frame_name: string") + .Attr("is_constant: bool = false") + .Attr("parallel_iterations: int = 10") + .Doc(R"doc( +Creates or finds a child frame, and makes `data` available to the child frame. 
+ +This op is used together with `Exit` to create loops in the graph. +The unique `frame_name` is used by the `Executor` to identify frames. If +`is_constant` is true, `output` is a constant in the child frame; otherwise +it may be changed in the child frame. At most `parallel_iterations` iterations +are run in parallel in the child frame. + +data: The tensor to be made available to the child frame. +frame_name: The name of the child frame. +is_constant: If true, the output is constant within the child frame. +parallel_iterations: The number of iterations allowed to run in parallel. +output: The same tensor as `data`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("RefEnter") + .Input("data: Ref(T)") + .Output("output: Ref(T)") + .Attr("T: type") + .Attr("frame_name: string") + .Attr("is_constant: bool = false") + .Attr("parallel_iterations: int = 10") + .Doc(R"doc( +Creates or finds a child frame, and makes `data` available to the child frame. + +The unique `frame_name` is used by the `Executor` to identify frames. If +`is_constant` is true, `output` is a constant in the child frame; otherwise +it may be changed in the child frame. At most `parallel_iterations` iterations +are run in parallel in the child frame. + +data: The tensor to be made available to the child frame. +frame_name: The name of the child frame. +is_constant: If true, the output is constant within the child frame. +parallel_iterations: The number of iterations allowed to run in parallel. +output: The same tensor as `data`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Exit") + .Input("data: T") + .Output("output: T") + .Attr("T: type") + .Doc(R"doc( +Exits the current frame to its parent frame. + +Exit makes its input `data` available to the parent frame. + +data: The tensor to be made available to the parent frame. +output: The same tensor as `data`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("NextIteration") + .Input("data: T") + .Output("output: T") + .Attr("T: type") + .Doc(R"doc( +Makes its input available to the next iteration. + +data: The tensor to be made available to the next iteration. +output: The same tensor as `data`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("LoopCond") + .Input("input: bool") + .Output("output: bool") + .Doc(R"doc( +Forwards the input to the output. + +This operator represents the loop termination condition used by the +"pivot" switches of a loop. + +input:= A boolean scalar, representing the branch predicate of the Switch op. +output: The same tensor as `input`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("ControlTrigger") + .Doc(R"doc( +Does nothing. Serves as a control trigger for scheduling. Only useful as a +placeholder for control edges. 
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc new file mode 100644 index 0000000000..49eba33188 --- /dev/null +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -0,0 +1,357 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- + +REGISTER_OP("DynamicPartition") + .Input("data: T") + .Input("partitions: int32") + .Output("outputs: num_partitions * T") + .Attr("num_partitions: int") + .Attr("T: type") + .Doc(R"doc( +Partitions `data` into `num_partitions` tensors using indices from `partitions`. + +For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` +becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` +are placed in `outputs[i]` in lexicographic order of `js`, and the first +dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. +In detail, + + outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] + + outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) + +`data.shape` must start with `partitions.shape`. + +For example: + + # Scalar partitions + partitions = 1 + num_partitions = 2 + data = [10, 20] + outputs[0] = [] # Empty with shape [0, 2] + outputs[1] = [[10, 20]] + + # Vector partitions + partitions = [0, 0, 1, 1, 0] + num_partitions = 2 + data = [10, 20, 30, 40, 50] + outputs[0] = [10, 20, 50] + outputs[1] = [30, 40] + +
+ +partitions: Any shape. Indices in the range `[0, num_partitions)`. +num_partitions: The number of partitions to output. +)doc"); + +REGISTER_OP("DynamicStitch") + .Input("indices: N * int32") + .Input("data: N * T") + .Output("merged: T") + .Attr("N : int >= 2") + .Attr("T : type") + .Doc(R"doc( +Interleave the values from the `data` tensors into a single tensor. + +Builds a merged tensor such that + + merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] + +For example, if each `indices[m]` is scalar or vector, we have + + # Scalar indices + merged[indices[m], ...] = data[m][...] + + # Vector indices + merged[indices[m][i], ...] = data[m][i, ...] + +Each `data[i].shape` must start with the corresponding `indices[i].shape`, +and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +must have `data[i].shape = indices[i].shape + constant`. In terms of this +`constant`, the output shape is + + merged.shape = [max(indices)] + constant + +Values are merged in order, so if an index appears in both `indices[m][i]` and +`indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the +merged result. + +For example: + + indices[0] = 6 + indices[1] = [4, 1] + indices[2] = [[5, 2], [0, 3]] + data[0] = [61, 62] + data[1] = [[41, 42], [11, 12]] + data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] + merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], + [51, 52], [61, 62]] + +
+)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("RandomShuffleQueue") + .Output("handle: Ref(string)") + .Attr("component_types: list(type) >= 1") + .Attr("shapes: list(shape) >= 0 = []") + .Attr("capacity: int = -1") + .Attr("min_after_dequeue: int = 0") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +A queue that randomizes the order of elements. + +handle: The handle to the queue. +component_types: The type of each component in a value. +shapes: The shape of each component in a value. The length of this attr must + be either 0 or the same as the length of component_types. If the length of + this attr is 0, the shapes of queue elements are not constrained, and + only one element may be dequeued at a time. +capacity: The upper bound on the number of elements in this queue. + Negative numbers mean no limit. +min_after_dequeue: Dequeue will block unless there would be this + many elements after the dequeue or the queue is closed. This + ensures a minimum level of mixing of elements. +seed: If either seed or seed2 is set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, a random seed is used. +seed2: A second seed to avoid seed collision. +container: If non-empty, this queue is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this queue will be shared under the given name + across multiple sessions. +)doc"); + +REGISTER_OP("FIFOQueue") + .Output("handle: Ref(string)") + .Attr("component_types: list(type) >= 1") + .Attr("shapes: list(shape) >= 0 = []") + .Attr("capacity: int = -1") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +A queue that produces elements in first-in first-out order. + +handle: The handle to the queue. +component_types: The type of each component in a value. +shapes: The shape of each component in a value. The length of this attr must + be either 0 or the same as the length of component_types. If the length of + this attr is 0, the shapes of queue elements are not constrained, and + only one element may be dequeued at a time. +capacity: The upper bound on the number of elements in this queue. + Negative numbers mean no limit. +container: If non-empty, this queue is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this queue will be shared under the given name + across multiple sessions. +)doc"); + +REGISTER_OP("QueueEnqueue") + .Input("handle: Ref(string)") + .Input("components: Tcomponents") + .Attr("Tcomponents: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .Doc(R"doc( +Enqueues a tuple of one or more tensors in the given queue. + +The components input has k elements, which correspond to the components of +tuples stored in the given queue. + +N.B. If the queue is full, this operation will block until the given +element has been enqueued (or 'timeout_ms' elapses, if specified). + +handle: The handle to a queue. +components: One or more tensors from which the enqueued tensors should be taken. +timeout_ms: If the queue is full, this operation will block for up to + timeout_ms milliseconds. + Note: This option is not supported yet. 
+)doc"); + +REGISTER_OP("QueueEnqueueMany") + .Input("handle: Ref(string)") + .Input("components: Tcomponents") + .Attr("Tcomponents: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .Doc(R"doc( +Enqueues zero or more tuples of one or more tensors in the given queue. + +This operation slices each component tensor along the 0th dimension to +make multiple queue elements. All of the tuple components must have the +same size in the 0th dimension. + +The components input has k elements, which correspond to the components of +tuples stored in the given queue. + +N.B. If the queue is full, this operation will block until the given +elements have been enqueued (or 'timeout_ms' elapses, if specified). + +handle: The handle to a queue. +components: One or more tensors from which the enqueued tensors should + be taken. +timeout_ms: If the queue is too full, this operation will block for up + to timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + +REGISTER_OP("QueueDequeue") + .Input("handle: Ref(string)") + .Output("components: component_types") + .Attr("component_types: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .Doc(R"doc( +Dequeues a tuple of one or more tensors from the given queue. + +This operation has k outputs, where k is the number of components +in the tuples stored in the given queue, and output i is the ith +component of the dequeued tuple. + +N.B. If the queue is empty, this operation will block until an element +has been dequeued (or 'timeout_ms' elapses, if specified). + +handle: The handle to a queue. +components: One or more tensors that were dequeued as a tuple. +component_types: The type of each component in a tuple. +timeout_ms: If the queue is empty, this operation will block for up to + timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + +REGISTER_OP("QueueDequeueMany") + .Input("handle: Ref(string)") + .Input("n: int32") + .Output("components: component_types") + .Attr("component_types: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .Doc(R"doc( +Dequeues n tuples of one or more tensors from the given queue. + +This operation concatenates queue-element component tensors along the +0th dimension to make a single component tensor. All of the components +in the dequeued tuple will have size n in the 0th dimension. + +This operation has k outputs, where k is the number of components in +the tuples stored in the given queue, and output i is the ith +component of the dequeued tuple. + +N.B. If the queue is empty, this operation will block until n elements +have been dequeued (or 'timeout_ms' elapses, if specified). + +handle: The handle to a queue. +n: The number of tuples to dequeue. +components: One or more tensors that were dequeued as a tuple. +component_types: The type of each component in a tuple. +timeout_ms: If the queue has fewer than n elements, this operation + will block for up to timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + +REGISTER_OP("QueueClose") + .Input("handle: Ref(string)") + .Attr("cancel_pending_enqueues: bool = false") + .Doc(R"doc( +Closes the given queue. + +This operation signals that no more elements will be enqueued in the +given queue. Subsequent Enqueue(Many) operations will fail. +Subsequent Dequeue(Many) operations will continue to succeed if +sufficient elements remain in the queue. Subsequent Dequeue(Many) +operations that would block will fail immediately. + +handle: The handle to a queue. 
+cancel_pending_enqueues: If true, all pending enqueue requests that are + blocked on the given queue will be cancelled. +)doc"); + +REGISTER_OP("QueueSize") + .Input("handle: Ref(string)") + .Output("size: int32") + .Doc(R"doc( +Computes the number of elements in the given queue. + +handle: The handle to a queue. +size: The number of elements in the given queue. +)doc"); + + +// -------------------------------------------------------------------------- + +REGISTER_OP("LookupTableFind") + .Input("table_handle: Ref(string)") + .Input("input_values: Tin") + .Input("default_value: Tout") + .Output("output_values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .Doc(R"doc( +Maps elements of a tensor into associated values given a lookup table. + +If an element of the input_values is not present in the table, the +specified default_value is used. + +The table needs to be initialized and the input and output types correspond +to the table key and value types. + +table_handle: A handle for a lookup table. +input_values: A vector of key values. +default_value: A scalar to return if the input is not found in the table. +output_values: A vector of values associated to the inputs. +)doc"); + +REGISTER_OP("LookupTableSize") + .Input("table_handle: Ref(string)") + .Output("size: int64") + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: The handle to a lookup table. +size: The number of elements in the given table. +)doc"); + +REGISTER_OP("HashTable") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Doc(R"doc( +Creates and holds an immutable hash table. + +The key and value types can be specified. After initialization, the table +becomes immutable. + +table_handle: a handle of a the lookup table. +container: If non-empty, this hash table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this hash table is shared under the given name across + multiple sessions. +key_dtype: the type of the table key. +value_dtype: the type of the table value. +)doc"); + +REGISTER_OP("InitializeTable") + .Input("table_handle: Ref(string)") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: a handle of the lookup table to be initialized. +keys: a vector of keys of type Tkey. +values: a vector of values of type Tval. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc new file mode 100644 index 0000000000..88af081893 --- /dev/null +++ b/tensorflow/core/ops/image_ops.cc @@ -0,0 +1,273 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- +REGISTER_OP("ResizeArea") + .Input("images: T") + .Input("size: int32") + .Output("resized_images: float") + .Attr("T: {uint8, int8, int32, float, double}") + .Doc(R"doc( +Resize `images` to `size` using area interpolation. + +Input images can be of different types but output images are always float. + +images: 4-D with shape `[batch, height, width, channels]`. +size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The + new size for the images. +resized_images: 4-D with shape + `[batch, new_height, new_width, channels]`. 
+)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("ResizeBicubic") + .Input("images: T") + .Input("size: int32") + .Output("resized_images: float") + .Attr("T: {uint8, int8, int32, float, double}") + .Doc(R"doc( +Resize `images` to `size` using bicubic interpolation. + +Input images can be of different types but output images are always float. + +images: 4-D with shape `[batch, height, width, channels]`. +size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The + new size for the images. +resized_images: 4-D with shape + `[batch, new_height, new_width, channels]`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("ResizeBilinear") + .Input("images: T") + .Input("size: int32") + .Output("resized_images: float") + .Attr("T: {uint8, int8, int32, float, double}") + .Doc(R"doc( +Resize `images` to `size` using bilinear interpolation. + +Input images can be of different types but output images are always float. + +images: 4-D with shape `[batch, height, width, channels]`. +size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The + new size for the images. +resized_images: 4-D with shape + `[batch, new_height, new_width, channels]`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("ResizeNearestNeighbor") + .Input("images: T") + .Input("size: int32") + .Output("resized_images: T") + .Attr("T: {uint8, int8, int32, float, double}") + .Doc(R"doc( +Resize `images` to `size` using nearest neighbor interpolation. + +Input images can be of different types but output images are always float. + +images: 4-D with shape `[batch, height, width, channels]`. +size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The + new size for the images. +resized_images: 4-D with shape + `[batch, new_height, new_width, channels]`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("RandomCrop") + .Input("image: T") + .Input("size: int64") + .Output("output: T") + .Attr("T: {uint8, int8, int16, int32, int64, float, double}") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .SetIsStateful() + .Doc(R"doc( +Randomly crop `image`. + +`size` is a 1-D int64 tensor with 2 elements representing the crop height and +width. The values must be non negative. + +This Op picks a random location in `image` and crops a `height` by `width` +rectangle from that location. The random location is picked so the cropped +area will fit inside the original image. + +image: 3-D of shape `[height, width, channels]`. +size: 1-D of length 2 containing: `crop_height`, `crop_width`.. +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. +output: 3-D of shape `[crop_height, crop_width, channels].` +)doc"); +// TODO(shlens): Support variable rank in RandomCrop. + +// -------------------------------------------------------------------------- +REGISTER_OP("DecodeJpeg") + .Input("contents: string") + .Attr("channels: int = 0") + .Attr("ratio: int = 1") + .Attr("fancy_upscaling: bool = true") + .Attr("try_recover_truncated: bool = false") + .Attr("acceptable_fraction: float = 1.0") + .Output("image: uint8") + .Doc(R"doc( +Decode a JPEG-encoded image to a uint8 tensor. + +The attr `channels` indicates the desired number of color channels for the +decoded image. 
+ +Accepted values are: + +* 0: Use the number of channels in the JPEG-encoded image. +* 1: output a grayscale image. +* 3: output an RGB image. + +If needed, the JPEG-encoded image is transformed to match the requested number +of color channels. + +The attr `ratio` allows downscaling the image by an integer factor during +decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +downscaling the image later. + +contents: 0-D. The JPEG-encoded image. +channels: Number of color channels for the decoded image. +ratio: Downscaling ratio. +fancy_upscaling: If true use a slower but nicer upscaling of the + chroma planes (yuv420/422 only). +try_recover_truncated: If true try to recover an image from truncated input. +acceptable_fraction: The minimum required fraction of lines before a truncated + input is accepted. +image: 3-D with shape `[height, width, channels]`.. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("EncodeJpeg") + .Input("image: uint8") + .Attr("format: {'', 'grayscale', 'rgb'} = ''") + .Attr("quality: int = 95") + .Attr("progressive: bool = false") + .Attr("optimize_size: bool = false") + .Attr("chroma_downsampling: bool = true") + .Attr("density_unit: {'in', 'cm'} = 'in'") + .Attr("x_density: int = 300") + .Attr("y_density: int = 300") + .Attr("xmp_metadata: string = ''") + .Output("contents: string") + .Doc(R"doc( +JPEG-encode an image. + +`image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. + +The attr `format` can be used to override the color format of the encoded +output. Values can be: + +* `''`: Use a default format based on the number of channels in the image. +* `grayscale`: Output a grayscale JPEG image. The `channels` dimension + of `image` must be 1. +* `rgb`: Output an RGB JPEG image. The `channels` dimension + of `image` must be 3. + +If `format` is not specified or is the empty string, a default format is picked +in function of the number of channels in `image`: + +* 1: Output a grayscale image. +* 3: Output an RGB image. + +image: 3-D with shape `[height, width, channels]`. +format: Per pixel image format. +quality: Quality of the compression from 0 to 100 (higher is better and slower). +progressive: If True, create a JPEG that loads progressively (coarse to fine). +optimize_size: If True, spend CPU/RAM to reduce size with no quality change. +chroma_downsampling: See http://en.wikipedia.org/wiki/Chroma_subsampling. +density_unit: Unit used to specify `x_density` and `y_density`: + pixels per inch (`'in'`) or centimeter (`'cm'`). +x_density: Horizontal pixels per density unit. +y_density: Vertical pixels per density unit. +xmp_metadata: If not empty, embed this XMP metadata in the image header. +contents: 0-D. JPEG-encoded image. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("AdjustContrast") + .Input("images: T") + .Input("contrast_factor: float") + .Input("min_value: float") + .Input("max_value: float") + .Output("output: float") + .Attr("T: {uint8, int8, int16, int32, int64, float, double}") + .Doc(R"Doc( +Adjust the contrast of one or more images. + +`images` is a tensor of at least 3 dimensions. The last 3 dimensions are +interpreted as `[height, width, channels]`. The other dimensions only +represent a collection of images, such as `[batch, height, width, channels].` + +Contrast is adjusted independently for each channel of each image. 
+ +For each channel, the Op first computes the mean of the image pixels in the +channel and then adjusts each component of each pixel to +`(x - mean) * contrast_factor + mean`. + +These adjusted values are then clipped to fit in the `[min_value, max_value]` +interval. + +`images: Images to adjust. At least 3-D. +contrast_factor: A float multiplier for adjusting contrast. +min_value: Minimum value for clipping the adjusted pixels. +max_value: Maximum value for clipping the adjusted pixels. +output: The constrast-adjusted image or images. +)Doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("DecodePng") + .Input("contents: string") + .Attr("channels: int = 0") + .Output("image: uint8") + .Doc(R"doc( +Decode a PNG-encoded image to a uint8 tensor. + +The attr `channels` indicates the desired number of color channels for the +decoded image. + +Accepted values are: + +* 0: Use the number of channels in the PNG-encoded image. +* 1: output a grayscale image. +* 3: output an RGB image. +* 4: output an RGBA image. + +If needed, the PNG-encoded image is transformed to match the requested number +of color channels. + +contents: 0-D. The PNG-encoded image. +channels: Number of color channels for the decoded image. +image: 3-D with shape `[height, width, channels]`. +)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("EncodePng") + .Input("image: uint8") + .Attr("compression: int = -1") + .Output("contents: string") + .Doc(R"doc( +PNG-encode an image. + +`image` is a 3-D uint8 Tensor of shape `[height, width, channels]` where +`channels` is: + +* 1: for grayscale. +* 3: for RGB. +* 4: for RGBA. + +The ZLIB compression level, `compression`, can be -1 for the PNG-encoder +default or a value from 0 to 9. 9 is the highest compression level, generating +the smallest output, but is slower. + +image: 3-D with shape `[height, width, channels]`. +compression: Compression level. +contents: 0-D. PNG-encoded image. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc new file mode 100644 index 0000000000..937fedd45d --- /dev/null +++ b/tensorflow/core/ops/io_ops.cc @@ -0,0 +1,332 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +REGISTER_OP("Save") + .Input("filename: string") + .Input("tensor_names: string") + .Input("data: T") + .Attr("T: list({float, double, int32, int64, quint8, qint8, qint32})") + .Doc(R"doc( +Saves the input tensors to disk. + +The size of `tensor_names` must match the number of tensors in `data`. `data[i]` +is written to `filename` with name `tensor_names[i]`. + +See also `SaveSlices`. + +filename: Must have a single element. The name of the file to which we write +the tensor. +tensor_names: Shape `[N]`. The names of the tensors to be saved. +data: `N` tensors to save. +)doc"); + +REGISTER_OP("SaveSlices") + .Input("filename: string") + .Input("tensor_names: string") + .Input("shapes_and_slices: string") + .Input("data: T") + .Attr("T: list({float, double, int32, int64, quint8, qint8, qint32})") + .Doc(R"doc( +Saves input tensors slices to disk. + +This is like `Save` except that tensors can be listed in the saved file as being +a slice of a larger tensor. `shapes_and_slices` specifies the shape of the +larger tensor and the slice that this tensor covers. `shapes_and_slices` must +have as many elements as `tensor_names`. 
+ +Elements of the `shapes_and_slices` input must either be: + +* The empty string, in which case the corresponding tensor is + saved normally. +* A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the + `dimI` are the dimensions of the larger tensor and `slice-spec` + specifies what part is covered by the tensor to save. + +`slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` +where each `sliceI` is either: + +* The string `-` meaning that the slice covers all indices of this dimension +* `start,length` where `start` and `length` are integers. In that + case the slice covers `length` indices starting at `start`. + +See also `Save`. + +filename: Must have a single element. The name of the file to which we write the +tensor. +tensor_names: Shape `[N]`. The names of the tensors to be saved. +shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when +saving the tensors. +data: `N` tensors to save. +)doc"); + +REGISTER_OP("Restore") + .Input("file_pattern: string") + .Input("tensor_name: string") + .Output("tensor: dt") + .Attr("dt: type") + .Attr("preferred_shard: int = -1") + .Doc(R"doc( +Restores a tensor from checkpoint files. + +Reads a tensor stored in one or several files. If there are several files (for +instance because a tensor was saved as slices), `file_pattern` may contain +wildcard symbols (`*` and `?`) in the filename portion only, not in the +directory portion. + +If a `file_pattern` matches several files, `preferred_shard` can be used to hint +in which file the requested tensor is likely to be found. This op will first +open the file at index `preferred_shard` in the list of matching files and try +to restore tensors from that file. Only if some tensors or tensor slices are +not found in that first file, then the Op opens all the files. Setting +`preferred_shard` to match the value passed as the `shard` input +of a matching `Save` Op may speed up Restore. This attribute only affects +performance, not correctness. The default value -1 means files are processed in +order. + +See also `RestoreSlice`. + +file_pattern: Must have a single element. The pattern of the files from + which we read the tensor. +tensor_name: Must have a single element. The name of the tensor to be + restored. +tensor: The restored tensor. +dt: The type of the tensor to be restored. +preferred_shard: Index of file to open first if multiple files match + `file_pattern`. +)doc"); + +REGISTER_OP("RestoreSlice") + .Input("file_pattern: string") + .Input("tensor_name: string") + .Input("shape_and_slice: string") + .Output("tensor: dt") + .Attr("dt: type") + .Attr("preferred_shard: int = -1") + .Doc(R"doc( +Restores a tensor from checkpoint files. + +This is like `Restore` except that restored tensor can be listed as filling +only a slice of a larger tensor. `shape_and_slice` specifies the shape of the +larger tensor and the slice that the restored tensor covers. + +The `shape_and_slice` input has the same format as the +elements of the `shapes_and_slices` input of the `SaveSlices` op. + +file_pattern: Must have a single element. The pattern of the files from + which we read the tensor. +tensor_name: Must have a single element. The name of the tensor to be + restored. +shape_and_slice: Scalar. The shapes and slice specifications to use when + restoring a tensors. +tensor: The restored tensor. +dt: The type of the tensor to be restored. +preferred_shard: Index of file to open first if multiple files match + `file_pattern`. See the documentation for `Restore`. 
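+
+For example, an illustrative `shape_and_slice` value in the `SaveSlices`
+format described above (the concrete numbers here are made up for this sketch):
+
+```
+# Restore rows 0 and 1 (and all columns) of a saved [4, 6] tensor:
+"4 6 0,2:-"
+```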
+)doc"); + +REGISTER_OP("ShardedFilename") + .Input("basename: string") + .Input("shard: int32") + .Input("num_shards: int32") + .Output("filename: string") + .Doc(R"doc( +Generate a sharded filename. The filename is printf formated as + %s-%05d-of-%05d, basename, shard, num_shards. +)doc"); + +REGISTER_OP("ShardedFilespec") + .Input("basename: string") + .Input("num_shards: int32") + .Output("filename: string") + .Doc(R"doc( +Generate a glob pattern matching all sharded file names. +)doc"); + +// Reader source ops ---------------------------------------------------------- + +REGISTER_OP("WholeFileReader") + .Output("reader_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +A Reader that outputs the entire contents of a file as a value. + +To use, enqueue filenames in a Queue. The output of ReaderRead will +be a filename (key) and the contents of that file (value). + +reader_handle: The handle to reference the Reader. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + +REGISTER_OP("TextLineReader") + .Output("reader_handle: Ref(string)") + .Attr("skip_header_lines: int = 0") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +A Reader that outputs the lines of a file delimited by '\n'. + +reader_handle: The handle to reference the Reader. +skip_header_lines: Number of lines to skip from the beginning of every file. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + +REGISTER_OP("FixedLengthRecordReader") + .Output("reader_handle: Ref(string)") + .Attr("header_bytes: int = 0") + .Attr("record_bytes: int") + .Attr("footer_bytes: int = 0") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +A Reader that outputs fixed-length records from a file. + +reader_handle: The handle to reference the Reader. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + +REGISTER_OP("TFRecordReader") + .Output("reader_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +A Reader that outputs the records from a TensorFlow Records file. + +reader_handle: The handle to reference the Reader. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + +REGISTER_OP("IdentityReader") + .Output("reader_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +A Reader that outputs the queued work as both the key and value. + +To use, enqueue strings in a Queue. ReaderRead will take the front +work string and output (work, work). + +reader_handle: The handle to reference the Reader. 
+container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + +// Ops that operate on Readers ------------------------------------------------ + +REGISTER_OP("ReaderRead") + .Input("reader_handle: Ref(string)") + .Input("queue_handle: Ref(string)") + .Output("key: string") + .Output("value: string") + .Doc(R"doc( +Returns the next record (key, value pair) produced by a Reader. + +Will dequeue from the input queue if necessary (e.g. when the +Reader needs to start reading from a new file since it has finished +with the previous file). + +reader_handle: Handle to a Reader. +queue_handle: Handle to a Queue, with string work items. +key: A scalar. +value: A scalar. +)doc"); + +REGISTER_OP("ReaderNumRecordsProduced") + .Input("reader_handle: Ref(string)") + .Output("records_produced: int64") + .Doc(R"doc( +Returns the number of records this Reader has produced. + +This is the same as the number of ReaderRead executions that have +succeeded. + +reader_handle: Handle to a Reader. +)doc"); + +REGISTER_OP("ReaderNumWorkUnitsCompleted") + .Input("reader_handle: Ref(string)") + .Output("units_completed: int64") + .Doc(R"doc( +Returns the number of work units this Reader has finished processing. + +reader_handle: Handle to a Reader. +)doc"); + +REGISTER_OP("ReaderSerializeState") + .Input("reader_handle: Ref(string)") + .Output("state: string") + .Doc(R"doc( +Produce a string tensor that encodes the state of a Reader. + +Not all Readers support being serialized, so this can produce an +Unimplemented error. + +reader_handle: Handle to a Reader. +)doc"); + +REGISTER_OP("ReaderRestoreState") + .Input("reader_handle: Ref(string)") + .Input("state: string") + .Doc(R"doc( +Restore a reader to a previously saved state. + +Not all Readers support being restored, so this can produce an +Unimplemented error. + +reader_handle: Handle to a Reader. +state: Result of a ReaderSerializeState of a Reader with type + matching reader_handle. +)doc"); + +REGISTER_OP("ReaderReset") + .Input("reader_handle: Ref(string)") + .Doc(R"doc( +Restore a Reader to its initial clean state. + +reader_handle: Handle to a Reader. +)doc"); + +// Other input Ops ---------------------------------------------------------- + +REGISTER_OP("ReadFile") + .Input("filename: string") + .Output("contents: string") + .Doc(R"doc( +Reads and outputs the entire contents of the input filename. +)doc"); + +REGISTER_OP("MatchingFiles") + .Input("pattern: string") + .Output("filenames: string") + .Doc(R"doc( +Returns the set of files matching a pattern. + +Note that this routine only supports wildcard characters in the +basename portion of the pattern, not in the directory portion. + +pattern: A (scalar) shell wildcard pattern. +filenames: A vector of matching filenames. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc new file mode 100644 index 0000000000..a9b940295e --- /dev/null +++ b/tensorflow/core/ops/linalg_ops.cc @@ -0,0 +1,97 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("MatrixDeterminant") + .Input("input: T") + .Output("output: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Calculates the determinant of a square matrix. + +input: A tensor of shape `[M, M]`. +output: A scalar, equal to the determinant of the input. 
+T: The type of values in the input and output. +)doc"); + +REGISTER_OP("BatchMatrixDeterminant") + .Input("input: T") + .Output("output: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Calculates the determinants for a batch of square matrices. + +The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. The output is a 1-D tensor containing the determinants +for all input submatrices `[..., :, :]`. + +input: Shape is `[..., M, M]`. +output: Shape is `[...]`. +T: The type of values in the input and output. +)doc"); + +REGISTER_OP("MatrixInverse") + .Input("input: T") + .Output("output: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Calculates the inverse of a square invertible matrix. Checks for invertibility. + +input: Shape is `[M, M]`. +output: Shape is `[M, M]` containing the matrix inverse of the input. +T: The type of values in the input and output. +)doc"); + +REGISTER_OP("BatchMatrixInverse") + .Input("input: T") + .Output("output: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Calculates the inverse of square invertible matrices. Checks for invertibility. + +The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. The output is a tensor of the same shape as the input +containing the inverse for all input submatrices `[..., :, :]`. + +input: Shape is `[..., M, M]`. +output: Shape is `[..., M, M]`. +T: The type of values in the input and output. +)doc"); + +REGISTER_OP("Cholesky") + .Input("input: T") + .Output("output: T") + .Attr("T: {double, float}") + .Doc(R"doc( +Calculates the Cholesky decomposition of a square matrix. + +The input has to be symmetric and positive definite. Only the lower-triangular +part of the input will be used for this operation. The upper-triangular part +will not be read. + +The result is the lower-triangular matrix of the Cholesky decomposition of the +input. + +input: Shape is `[M, M]`. +output: Shape is `[M, M]`. +T: The type of values in the input and output. +)doc"); + +REGISTER_OP("BatchCholesky") + .Input("input: T") + .Output("output: T") + .Attr("T: {double, float}") + .Doc(R"doc( +Calculates the Cholesky decomposition of a batch of square matrices. + +The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices, with the same constraints as the single matrix Cholesky +decomposition above. The output is a tensor of the same shape as the input +containing the Cholesky decompositions for all input submatrices `[..., :, :]`. + +input: Shape is `[..., M, M]`. +output: Shape is `[..., M, M]`. +T: The type of values in the input and output. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc new file mode 100644 index 0000000000..28546fe645 --- /dev/null +++ b/tensorflow/core/ops/logging_ops.cc @@ -0,0 +1,43 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("Assert") + .Input("condition: bool") + .Input("data: T") + .Attr("T: list(type)") + .Attr("summarize: int = 3") + .Doc(R"doc( +Asserts that the given condition is true. + +If `condition` evaluates to false, print the list of tensors in `data`. +`summarize` determines how many entries of the tensors to print. + +condition: The condition to evaluate. +data: The tensors to print out when condition is false. +summarize: Print this many entries of each tensor. 
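+
+For example (an illustrative sketch; `tf.Assert` is the assumed Python wrapper
+for this op):
+
+```
+# 'condition' is a scalar bool tensor, 'x' is [10, -1, 3]
+# If 'condition' evaluates to False, the first `summarize` entries of 'x'
+# are printed along with the error.
+tf.Assert(condition, [x], summarize=3)
+```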
+)doc"); + +REGISTER_OP("Print") + .Input("input: T") + .Input("data: U") + .Output("output: T") + .Attr("T: type") + .Attr("U: list(type)") + .Attr("message: string = ''") + .Attr("first_n: int = -1") + .Attr("summarize: int = 3") + .Doc(R"doc( +Prints a list of tensors. + +Passes `input` through to `output` and prints `data` when evaluating. + +input: The tensor passed to `output` +data: A list of tensors to print out when op is evaluated. +output:= The unmodified `input` tensor +message: A string, prefix of the error message. +first_n: Only log `first_n` number of times. -1 disables logging. +summarize: Only print this many entries of each tensor. +)doc"); + +} // end namespace tensorflow diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc new file mode 100644 index 0000000000..20e56316ea --- /dev/null +++ b/tensorflow/core/ops/math_ops.cc @@ -0,0 +1,1053 @@ +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("AddN") + .Input("inputs: N * T") + .Output("sum: T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .SetIsCommutative() + .SetIsAggregate() + .Doc(R"doc( +Add all input tensors element wise. + +inputs: Must all be the same size and shape. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("BatchMatMul") + .Input("x: T") + .Input("y: T") + .Output("out: T") + .Attr("T: {float, double, int32, complex64}") + .Attr("adj_x: bool = false") + .Attr("adj_y: bool = false") + .Doc(R"doc( +Multiplies slices of two tensors in batches. + +Multiplies all slices of `Tensor` `x` and `y` (each slice can be +viewed as an element of a batch), and arranges the individual results +in a single output tensor of the same batch size. Each of the +individual slices can optionally be adjointed (to adjoint a matrix +means to transpose and conjugate it) before multiplication by setting +the `adj_x` or `adj_y` flag to `True`, which are by default `False`. + +The input tensors `x` and `y` are 3-D or higher with shape `[..., r_x, c_x]` +and `[..., r_y, c_y]`. + +The output tensor is 3-D or higher with shape `[..., r_o, c_o]`, where: + + r_o = c_x if adj_x else r_x + c_o = r_y if adj_y else c_y + +It is computed as: + + out[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) + +x: 3-D or higher with shape `[..., r_x, c_x]`. +y: 3-D or higher with shape `[..., r_y, c_y]`. +out: 3-D or higher with shape `[..., r_o, c_o]` +adj_x: If `True`, adjoint the slices of `x`. Defaults to `False`. +adj_y: If `True`, adjoint the slices of `y`. Defaults to `False`. +)doc"); + +// -------------------------------------------------------------------------- +// Casting Ops +// +// NOTE: Only a smaller number of types are supported by +// Cast. The exact casting rule is TBD. The current +// implementation uses C++ static cast rules for numeric +// types, which may be changed in the future. +REGISTER_OP("Cast") + .Input("x: SrcT") + .Output("y: DstT") + .Attr("SrcT: type") + .Attr("DstT: type") + .Doc(R"doc( +Cast x of type SrcT to y of DstT. +)doc"); + +REGISTER_OP("_HostCast") + .Input("x: SrcT") + .Output("y: DstT") + .Attr("SrcT: type") + .Attr("DstT: type") + .Doc(R"doc( +Cast x of type SrcT to y of DstT. + +_HostCast requires its input and produces its output in host memory. 
+)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("Abs") + .Input("x: T") + .Output("y: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Computes the absolute value of a tensor. + +Given a tensor `x`, this operation returns a tensor containing the absolute +value of each element in `x`. For example, if x is an input element and y is +an output element, this operation computes \\(y = |x|\\). +)doc"); + +REGISTER_OP("ComplexAbs") + .Input("x: complex64") + .Output("y: float") + .Doc(R"doc( +Computes the complex absolute value of a tensor. + +Given a tensor `x` of complex numbers, this operation returns a tensor of type +`float` that is the absolute value of each element in `x`. All elements in `x` +must be complex numbers of the form \\(a + bj\\). The absolute value is +computed as \\( \sqrt{a^2 + b^2}\\). + +For example: + +``` +# tensor 'x' is [[-2.25 + 4.75j], [-3.25 + 5.75j]] +tf.complex_abs(x) ==> [5.25594902, 6.60492229] +``` +)doc"); + +// Declares cwise unary operations signature: 't -> 't +#define UNARY() \ + Input("x: T").Output("y: T").Attr( \ + "T: {float, double, int32, complex64, int64}") + +REGISTER_OP("Neg") + .UNARY() + .Doc(R"doc( +Computes numerical negative value element-wise. +I.e., \\(y = -x\\). +)doc"); + +REGISTER_OP("Inv") + .UNARY() + .Doc(R"doc( +Computes the reciprocal of x element-wise. +I.e., \\(y = 1 / x\\). +)doc"); + +REGISTER_OP("Square") + .UNARY() + .Doc(R"doc( +Computes square of x element-wise. +I.e., \\(y = x * x = x^2\\). +)doc"); + +REGISTER_OP("Sqrt") + .UNARY() + .Doc(R"doc( +Computes square root of x element-wise. +I.e., \\(y = \sqrt{x} = x^{1/2}\\). +)doc"); + +REGISTER_OP("Rsqrt") + .UNARY() + .Doc(R"doc( +Computes reciprocal of square root of x element-wise. +I.e., \\(y = 1 / \sqrt{x}\\). +)doc"); + +REGISTER_OP("Exp") + .UNARY() + .Doc(R"doc( +Computes exponential of x element-wise. \\(y = e^x\\). +)doc"); + +REGISTER_OP("Log") + .UNARY() + .Doc(R"doc( +Computes natural logrithm of x element-wise. +I.e., \\(y = \log_e x\\). +)doc"); + +REGISTER_OP("Tanh") + .UNARY() + .Doc(R"doc( +Computes hyperbolic tangent of `x` element-wise. +)doc"); + +REGISTER_OP("Sigmoid") + .UNARY() + .Doc(R"doc( +Computes sigmoid of `x` element-wise. + +Specifically, `y = 1 / (1 + exp(-x))`. +)doc"); + +REGISTER_OP("Sin") + .UNARY() + .Doc(R"doc( +Computes sin of x element-wise. +)doc"); + +REGISTER_OP("Cos") + .UNARY() + .Doc(R"doc( +Computes cos of x element-wise. +)doc"); + +#undef UNARY + +REGISTER_OP("IsNan") + .Input("x: T") + .Output("y: bool") + .Attr("T: {float, double}") + .Doc(R"doc( +Returns which elements of x are NaN. +)doc"); + +REGISTER_OP("IsInf") + .Input("x: T") + .Output("y: bool") + .Attr("T: {float, double}") + .Doc(R"doc( +Returns which elements of x are Inf. +)doc"); + +REGISTER_OP("IsFinite") + .Input("x: T") + .Output("y: bool") + .Attr("T: {float, double}") + .Doc(R"doc( +Returns which elements of x are finite. +)doc"); + +REGISTER_OP("Sign") + .Input("x: T") + .Output("y: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Returns an element-wise indication of the sign of a number. + +y = sign(x) = -1 if x < 0; 0 if x == 0; 1 if x > 0. +)doc"); + +REGISTER_OP("Floor") + .Input("x: T") + .Output("y: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Returns element-wise largest integer not greater than x. 
+)doc"); + +REGISTER_OP("Ceil") + .Input("x: T") + .Output("y: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Returns element-wise smallest integer in not less than x. +)doc"); + +// Declares cwise binary operations signature: 't, 't -> 't. + +#define BINARY_MORE() \ + Input("x: T").Input("y: T").Output("z: T").Attr( \ + "T: {float, double, int8, int16, int32, complex64, int64}") + +#define BINARY_FEWER() \ + Input("x: T").Input("y: T").Output("z: T").Attr( \ + "T: {float, double, int32, complex64, int64}") + +REGISTER_OP("Add") + .BINARY_MORE() + .SetIsCommutative() + .Doc(R"doc( +Returns x + y element-wise. + +*NOTE*: Add supports broadcasting. AddN does not. +)doc"); + +REGISTER_OP("Sub") + .BINARY_FEWER() + .Doc(R"doc( +Returns x - y element-wise. +)doc"); + +REGISTER_OP("Mul") + .BINARY_MORE() + .SetIsCommutative() + .Doc(R"doc( +Returns x * y element-wise. +)doc"); + +REGISTER_OP("Div") + .BINARY_FEWER() + .Doc(R"doc( +Returns x / y element-wise. +)doc"); + +#undef BINARY_FEWER +#undef BINARY_MORE + +REGISTER_OP("Maximum") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {float, double, int32, int64}") + .SetIsCommutative() + .Doc(R"doc( +Returns the max of x and y (i.e. x > y ? x : y) element-wise, broadcasts. +)doc"); + +REGISTER_OP("Minimum") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {float, double, int32, int64}") + .SetIsCommutative() + .Doc(R"doc( +Returns the min of x and y (i.e. x < y ? x : y) element-wise, broadcasts. +)doc"); + +REGISTER_OP("Mod") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {int32, int64, float, double}") + .Doc(R"doc( +Returns element-wise remainder of division. +)doc"); + +REGISTER_OP("Pow") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {float, double, int32, complex64, int64}") + .Doc(R"doc( +Computes the power of one value to another. + +Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for +corresponding elements in `x` and `y`. For example: + +``` +# tensor 'x' is [[2, 2]], [3, 3]] +# tensor 'y' is [[8, 16], [2, 3]] +tf.pow(x, y) ==> [[256, 65536], [9, 27]] +``` +)doc"); + +// -------------------------------------------------------------------------- + +// Declares cwise binary comparison operations signature: 't, 't -> bool, +// where 't has a natural total order. +#define COMPARISON() \ + Input("x: T").Input("y: T").Output("z: bool").Attr( \ + "T: {float, double, int32, int64}") + +REGISTER_OP("Less") + .COMPARISON() + .Doc(R"doc( +Returns the truth value of (x < y) element-wise. +)doc"); + +REGISTER_OP("LessEqual") + .COMPARISON() + .Doc(R"doc( +Returns the truth value of (x <= y) element-wise. +)doc"); + +REGISTER_OP("Greater") + .COMPARISON() + .Doc(R"doc( +Returns the truth value of (x > y) element-wise. +)doc"); + +REGISTER_OP("GreaterEqual") + .COMPARISON() + .Doc(R"doc( +Returns the truth value of (x >= y) element-wise. +)doc"); + +#undef COMPARISON + +// -------------------------------------------------------------------------- + +#define COMPARISON() \ + Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr( \ + "T: {float, double, int32, int64, complex64, quint8, qint8, qint32}") + +REGISTER_OP("Equal") + .COMPARISON() + .Doc(R"doc( +Returns the truth value of (x == y) element-wise. +)doc"); + +REGISTER_OP("NotEqual") + .COMPARISON() + .Doc(R"doc( +Returns the truth value of (x != y) element-wise. 
+)doc"); + +#undef COMPARISON + +// -------------------------------------------------------------------------- + +REGISTER_OP("LogicalNot") + .Input("x: bool") + .Output("y: bool") + .Doc(R"doc( +Returns the truth value of NOT x element-wise. +)doc"); + +#define BINARY_LOGICAL() \ + Input("x: bool").Input("y: bool").Output("z: bool").SetIsCommutative() + +REGISTER_OP("LogicalAnd") + .BINARY_LOGICAL() + .Doc(R"doc( +Returns the truth value of x AND y element-wise. +)doc"); + +REGISTER_OP("LogicalOr") + .BINARY_LOGICAL() + .Doc(R"doc( +Returns the truth value of x OR y element-wise. +)doc"); + +#undef BINARY_LOGICAL + +// -------------------------------------------------------------------------- + +REGISTER_OP("Select") + .Input("condition: bool") + .Input("t: T") + .Input("e: T") + .Output("out: T") + .Attr("T: type") + .Doc(R"doc( +Selects elements from `t` or `e`, depending on `condition`. + +The `condition`, `t`, and `e` tensors must all have the same shape, +and the output will also have that shape. The `condition` tensor acts +as an element-wise mask that chooses, based on the value at each +element, whether the corresponding element in the output should be +taken from `t` (if true) or `e` (if false). For example: + +For example: + +```prettyprint +# 'condition' tensor is [[True, False] +# [True, False]] +# 't' is [[1, 1], +# [1, 1]] +# 'e' is [[2, 2], +# [2, 2]] +select(condition, t, e) ==> [[1, 2], + [1, 2]] +``` + +t:= A `Tensor` with the same shape as `condition`. +e:= A `Tensor` with the same type and shape as `t`. +out:= A `Tensor` with the same type and shape as `t` and `e`. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("MatMul") + .Input("a: T") + .Input("b: T") + .Output("product: T") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("T: {float, double, int32, complex64}") + .Doc(R"doc( +Multiply the matrix "a" by the matrix "b". + +The inputs must be two-dimensional matrices and the inner dimension of +"a" (after being transposed if transpose_a is true) must match the +outer dimension of "b" (after being transposed if transposed_b is +true). + +*Note*: The default kernel implementation for MatMul on GPUs uses +cublas. + +transpose_a: If true, "a" is transposed before multiplication. +transpose_b: If true, "b" is transposed before multiplication. +)doc"); + +REGISTER_OP("SparseMatMul") + .Input("a: float") + .Input("b: float") + .Output("product: float") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("a_is_sparse: bool = false") + .Attr("b_is_sparse: bool = false") + .Doc(R"doc( +Multiply matrix "a" by matrix "b". + +The inputs must be two-dimensional matrices and the inner dimension of "a" must +match the outer dimension of "b". This op is optimized for the case where at +least one of "a" or "b" is sparse. The breakeven for using this versus a dense +matrix multiply on one platform was 30% zero values in the sparse matrix. +)doc"); + +// -------------------------------------------------------------------------- + +// For operations where the output is a reduction function along some +// dimensions of the input. +REGISTER_OP("Sum") + .Input("input: T") + .Input("reduction_indices: int32") + .Output("output: T") + .Attr("keep_dims: bool = false") + .Attr("T: numbertype") + .Doc(R"doc( +Computes the sum of elements across dimensions of a tensor. + +Reduces `input` along the dimensions given in `reduction_indices`. 
Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`reduction_indices`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +input: The tensor to reduce. +reduction_indices: The dimensions to reduce. +keep_dims: If true, retain reduced dimensions with length 1. +output: The reduced tensor. +)doc"); + +REGISTER_OP("Mean") + .Input("input: T") + .Input("reduction_indices: int32") + .Output("output: T") + .Attr("keep_dims: bool = false") + .Attr("T: numbertype") + .Doc(R"doc( +Computes the mean of elements across dimensions of a tensor. + +Reduces `input` along the dimensions given in `reduction_indices`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`reduction_indices`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +input: The tensor to reduce. +reduction_indices: The dimensions to reduce. +keep_dims: If true, retain reduced dimensions with length 1. +output: The reduced tensor. +)doc"); + +REGISTER_OP("Prod") + .Input("input: T") + .Input("reduction_indices: int32") + .Output("output: T") + .Attr("keep_dims: bool = false") + .Attr("T: numbertype") + .Doc(R"doc( +Computes the product of elements across dimensions of a tensor. + +Reduces `input` along the dimensions given in `reduction_indices`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`reduction_indices`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +input: The tensor to reduce. +reduction_indices: The dimensions to reduce. +keep_dims: If true, retain reduced dimensions with length 1. +output: The reduced tensor. +)doc"); + +REGISTER_OP("Min") + .Input("input: T") + .Input("reduction_indices: int32") + .Output("output: T") + .Attr("keep_dims: bool = false") + .Attr("T: numbertype") + .Doc(R"doc( +Computes the minimum of elements across dimensions of a tensor. + +Reduces `input` along the dimensions given in `reduction_indices`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`reduction_indices`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +input: The tensor to reduce. +reduction_indices: The dimensions to reduce. +keep_dims: If true, retain reduced dimensions with length 1. +output: The reduced tensor. +)doc"); + +REGISTER_OP("Max") + .Input("input: T") + .Input("reduction_indices: int32") + .Output("output: T") + .Attr("keep_dims: bool = false") + .Attr("T: numbertype") + .Doc(R"doc( +Computes the maximum of elements across dimensions of a tensor. + +Reduces `input` along the dimensions given in `reduction_indices`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`reduction_indices`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +input: The tensor to reduce. +reduction_indices: The dimensions to reduce. +keep_dims: If true, retain reduced dimensions with length 1. +output: The reduced tensor. +)doc"); + +REGISTER_OP("ArgMax") + .Input("input: T") + .Input("dimension: int32") + .Output("output: int64") + .Attr("T: numbertype") + .Doc(R"doc( +Returns the index with the largest value across dimensions of a tensor. + +dimension: int32, 0 <= dimension < rank(input). Describes which dimension + of the input Tensor to reduce across. For vectors, use dimension = 0. 
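+
+For example (illustrative; `tf.argmax` is the assumed Python wrapper):
+
+```
+# 'input' is [[10, 20, 5],
+#             [ 7,  1, 9]]
+# With dimension = 1, the index of the largest entry in each row is returned.
+tf.argmax(input, 1) ==> [1, 2]
+```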
+)doc"); + +REGISTER_OP("ArgMin") + .Input("input: T") + .Input("dimension: int32") + .Output("output: int64") + .Attr("T: numbertype") + .Doc(R"doc( +Returns the index with the smallest value across dimensions of a tensor. + +dimension: int32, 0 <= dimension < rank(input). Describes which dimension + of the input Tensor to reduce across. For vectors, use dimension = 0. +)doc"); + +REGISTER_OP("SegmentSum") + .Input("data: T") + .Input("segment_ids: Tindices") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Doc(R"doc( +Computes the sum along segments of a tensor. + +Read [the section on Segmentation](../python/math_ops.md#segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \sum_j data_j\\) where sum is over `j` such +that `segment_ids[j] == i`. + +
+ +segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +first dimension. Values should be sorted and can be repeated. + +output: Has same shape as data, except for dimension_0 which +has size `k`, the number of segments. +)doc"); + +REGISTER_OP("SegmentMean") + .Input("data: T") + .Input("segment_ids: Tindices") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Doc(R"doc( +Computes the mean along segments of a tensor. + +Read [the section on Segmentation](../python/math_ops.md#segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is +over `j` such that `segment_ids[j] == i` and `N` is the total number of +values summed. + +
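+For example (illustrative; `tf.segment_mean` is the assumed Python wrapper):
+
+```prettyprint
+# 'data' is [[1.0, 2.0],
+#            [3.0, 4.0],
+#            [5.0, 6.0]]
+tf.segment_mean(data, tf.constant([0, 0, 1]))
+  ==> [[2.0, 3.0],
+       [5.0, 6.0]]
+```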
+ +segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +first dimension. Values should be sorted and can be repeated. + +output: Has same shape as data, except for dimension_0 which +has size `k`, the number of segments. +)doc"); + +REGISTER_OP("SegmentProd") + .Input("data: T") + .Input("segment_ids: Tindices") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Doc(R"doc( +Computes the product along segments of a tensor. + +Read [the section on Segmentation](../python/math_ops.md#segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \prod_j data_j\\) where the product is over `j` such +that `segment_ids[j] == i`. + +
+ +segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +first dimension. Values should be sorted and can be repeated. + +output: Has same shape as data, except for dimension_0 which +has size `k`, the number of segments. +)doc"); + +REGISTER_OP("SegmentMin") + .Input("data: T") + .Input("segment_ids: Tindices") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Doc(R"doc( +Computes the minimum along segments of a tensor. + +Read [the section on Segmentation](../python/math_ops.md#segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \min_j(data_j)\\) where `min` is over `j` such +that `segment_ids[j] == i`. + +
+ +segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +first dimension. Values should be sorted and can be repeated. + +output: Has same shape as data, except for dimension_0 which +has size `k`, the number of segments. +)doc"); + +REGISTER_OP("SegmentMax") + .Input("data: T") + .Input("segment_ids: Tindices") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Doc(R"doc( +Computes the maximum along segments of a tensor. + +Read [the section on Segmentation](../python/math_ops.md#segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \max_j(data_j)\\) where `max` is over `j` such +that `segment_ids[j] == i`. + +
+ +segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +first dimension. Values should be sorted and can be repeated. + +output: Has same shape as data, except for dimension_0 which +has size `k`, the number of segments. +)doc"); + +REGISTER_OP("UnsortedSegmentSum") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: int32") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Doc(R"doc( +Computes the sum along segments of a tensor. + +Read [the section on Segmentation](../python/math_ops.md#segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \sum_j data_j\\) where sum is over `j` such +that `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids` +need not be sorted and need not cover all values in the full + range of valid values. + +If the sum is empty for a given segment ID `i`, `output[i] = 0`. + +`num_segments` should equal the number of distinct segment IDs. + +
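+For example (illustrative; `tf.unsorted_segment_sum` is the assumed Python
+wrapper; note that the ids need not be sorted):
+
+```prettyprint
+# 'data' is [1, 2, 3, 4]
+# 'segment_ids' is [1, 0, 1, 0] and 'num_segments' is 2
+tf.unsorted_segment_sum(data, segment_ids, 2) ==> [6, 4]
+```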
+ +segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +first dimension. + +output: Has same shape as data, except for dimension_0 which +has size `num_segments`. + +)doc"); + +REGISTER_OP("SparseSegmentSum") + .Input("data: T") + .Input("indices: int32") + .Input("segment_ids: int32") + .Output("output: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes the sum along sparse segments of a tensor. + +Read [the section on Segmentation](../python/math_ops.md#segmentation) +for an explanation of segments. + +Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first +dimension, selecting a subset of dimension_0, specified by `indices`. + +For example: + +```prettyprint +c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + +# Select two rows, one segment. +tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) + ==> [[0 0 0 0]] + +# Select two rows, two segment. +tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) + ==> [[ 1 2 3 4] + [-1 -2 -3 -4]] + +# Select all rows, two segments. +tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) + ==> [[0 0 0 0] + [5 6 7 8]] + +# Which is equivalent to: +tf.segment_sum(c, tf.constant([0, 0, 1])) +``` + +indices: A 1-D tensor. Has same rank as `segment_ids`. + +segment_ids: A 1-D tensor. Values should be sorted and can be repeated. + +output: Has same shape as data, except for dimension_0 which +has size `k`, the number of segments. +)doc"); + +REGISTER_OP("SparseSegmentMean") + .Input("data: T") + .Input("indices: int32") + .Input("segment_ids: int32") + .Output("output: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Computes the mean along sparse segments of a tensor. + +Read [the section on Segmentation](../python/math_ops.md#segmentation) +for an explanation of segments. + +Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first +dimension, selecting a subset of dimension_0, specified by `indices`. + +indices: A 1-D tensor. Has same rank as `segment_ids`. + +segment_ids: A 1-D tensor. Values should be sorted and can be repeated. + +output: Has same shape as data, except for dimension_0 which +has size `k`, the number of segments. + +)doc"); + +REGISTER_OP("SparseSegmentMeanGrad") + .Input("grad: T") + .Input("indices: int32") + .Input("segment_ids: int32") + .Input("output_dim0: int32") + .Output("output: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Computes gradients for SparseSegmentMean. + +Returns tensor "output" with same shape as grad, except for dimension_0 whose +value is output_dim0. + +grad: gradient propagated to the SparseSegmentMean op. +indices: indices passed to the corresponding SparseSegmentMean op. +segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. +output_dim0: dimension_0 of "data" passed to SparseSegmentMean op. +)doc"); + +REGISTER_OP("All") + .Input("input: bool") + .Input("reduction_indices: int32") + .Output("output: bool") + .Attr("keep_dims: bool = false") + .Doc(R"doc( +Computes the "logical and" of elements across dimensions of a tensor. + +Reduces `input` along the dimensions given in `reduction_indices`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`reduction_indices`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +input: The tensor to reduce. +reduction_indices: The dimensions to reduce. +keep_dims: If true, retain reduced dimensions with length 1. +output: The reduced tensor. 
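+
+For example (illustrative; `tf.reduce_all` is the assumed Python wrapper):
+
+```
+# 'input' is [[True,  True],
+#             [False, True]]
+tf.reduce_all(input, 0) ==> [False, True]
+tf.reduce_all(input, 1) ==> [True, False]
+```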
+)doc"); + +REGISTER_OP("Any") + .Input("input: bool") + .Input("reduction_indices: int32") + .Attr("keep_dims: bool = false") + .Output("output: bool") + .Doc(R"doc( +Computes the "logical or" of elements across dimensions of a tensor. + +Reduces `input` along the dimensions given in `reduction_indices`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`reduction_indices`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +input: The tensor to reduce. +reduction_indices: The dimensions to reduce. +keep_dims: If true, retain reduced dimensions with length 1. +output: The reduced tensor. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("Range") + .Input("start: int32") + .Input("limit: int32") + .Input("delta: int32") + .Output("output: int32") + .Doc(R"doc( +Creates a sequence of integers. + +This operation creates a sequence of integers that begins at `start` and +extends by increments of `delta` up to but not including `limit`. + +For example: + +``` +# 'start' is 3 +# 'limit' is 18 +# 'delta' is 3 +tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] +``` + +start: 0-D (scalar). First entry in the sequence. +limit: 0-D (scalar). Upper limit of sequence, exclusive. +delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. +output: 1-D. +)doc"); + +REGISTER_OP("LinSpace") + .Input("start: T") + .Input("stop: T") + .Input("num: int32") + .Output("output: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Generates values in an interval. + +A sequence of `num` evenly-spaced values are generated beginning at `start`. +If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, +so that the last one is exactly `stop`. + +For example: + +``` +tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] +``` + +start: First entry in the range. +stop: Last entry in the range. +num: Number of values to generate. +output: 1-D. The generated values. +)doc"); + +REGISTER_OP("Complex") + .Input("real: float") + .Input("imag: float") + .Output("out: complex64") + .Doc(R"doc( +Converts two real numbers to a complex number. + +Given a tensor `real` representing the real part of a complex number, and a +tensor `imag` representing the imaginary part of a complex number, this +operation returns complex numbers elementwise of the form \\(a + bj\\), where +*a* represents the `real` part and *b* represents the `imag` part. + +The input tensors `real` and `imag` must have the same shape. + +For example: + +``` +# tensor 'real' is [2.25, 3.25] +# tensor `imag` is [4.75, 5.75] +tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] +``` +)doc"); + +REGISTER_OP("Real") + .Input("in: complex64") + .Output("out: float") + .Doc(R"doc( +Returns the real part of a complex number. + +Given a tensor `in` of complex numbers, this operation returns a tensor of type +`float` that is the real part of each element in `in`. All elements in `in` +must be complex numbers of the form \\(a + bj\\), where *a* is the real part +returned by this operation and *b* is the imaginary part. + +For example: + +``` +# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j] +tf.real(in) ==> [-2.25, 3.25] +``` +)doc"); + +REGISTER_OP("Imag") + .Input("in: complex64") + .Output("out: float") + .Doc(R"doc( +Returns the imaginary part of a complex number. 
+ +Given a tensor `in` of complex numbers, this operation returns a tensor of type +`float` that is the imaginary part of each element in `in`. All elements in `in` +must be complex numbers of the form \\(a + bj\\), where *a* is the real part +and *b* is the imaginary part returned by this operation. + +For example: + +``` +# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j] +tf.imag(in) ==> [4.75, 5.75] +``` +)doc"); + +REGISTER_OP("Conj") + .Input("in: complex64") + .Output("out: complex64") + .Doc(R"doc( +Returns the complex conjugate of a complex number. + +Given a tensor `in` of complex numbers, this operation returns a tensor of +complex numbers that are the complex conjugate of each element in `in`. The +complex numbers in `in` must be of the form \\(a + bj\\), where *a* is the real +part and *b* is the imaginary part. + +The complex conjugate returned by this operation is of the form \\(a - bj\\). + +For example: + +``` +# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j] +tf.conj(in) ==> [-2.25 - 4.75j, 3.25 - 5.75j] +``` +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc new file mode 100644 index 0000000000..03ba49d5cd --- /dev/null +++ b/tensorflow/core/ops/nn_ops.cc @@ -0,0 +1,543 @@ +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/util/padding.h" +namespace tensorflow { + +// -------------------------------------------------------------------------- + +REGISTER_OP("AvgPool") + .Input("value: T") + .Output("output: T") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr("T: {float, double}") + .Doc(R"doc( +Performs average pooling on the input. + +Each entry in `output` is the mean of the corresponding size `ksize` +window in `value`. + +value: 4-D with shape `[batch, height, width, channels]`. +ksize: The size of the sliding window for each dimension of `value`. +strides: The stride of the sliding window for each dimension of `value`. +padding: The type of padding algorithm to use. +output: The average pooled output tensor. +)doc"); + +REGISTER_OP("AvgPoolGrad") + .Input("orig_input_shape: int32") + .Input("grad: T") + .Output("output: T") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr("T: {float, double}") + .Doc(R"doc( +Computes gradients of the average pooling function. + +orig_input_shape: 1-D. Shape of the original input to `avg_pool`. +grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. + the output of `avg_pool`. +ksize: The size of the sliding window for each dimension of the input. +strides: The stride of the sliding window for each dimension of the input. +padding: The type of padding algorithm to use. +output: 4-D. Gradients w.r.t. the input of `avg_pool`. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("BatchNormWithGlobalNormalization") + .Input("t: T") + .Input("m: T") + .Input("v: T") + .Input("beta: T") + .Input("gamma: T") + .Output("result: T") + .Attr("T: numbertype") + .Attr("variance_epsilon: float") + .Attr("scale_after_normalization: bool") + .Doc(R"doc( +Batch normalization. + +t: A 4D input Tensor. +m: A 1D mean Tensor with size matching the last dimension of t. + This is the first output from MovingMoments. +v: A 1D variance Tensor with size matching the last dimension of t. + This is the second output from MovingMoments. 
+beta: A 1D beta Tensor with size matching the last dimension of t.
+  An offset to be added to the normalized tensor.
+gamma: A 1D gamma Tensor with size matching the last dimension of t.
+  If "scale_after_normalization" is true, this tensor will be multiplied
+  with the normalized tensor.
+variance_epsilon: A small float number to avoid dividing by 0.
+scale_after_normalization: A bool indicating whether the resulting tensor
+  needs to be multiplied with gamma.
+)doc");
+
+REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
+    .Input("t: T")
+    .Input("m: T")
+    .Input("v: T")
+    .Input("gamma: T")
+    .Input("backprop: T")
+    .Output("dx: T")
+    .Output("dm: T")
+    .Output("dv: T")
+    .Output("db: T")
+    .Output("dg: T")
+    .Attr("T: numbertype")
+    .Attr("variance_epsilon: float")
+    .Attr("scale_after_normalization: bool")
+    .Doc(R"doc(
+Gradients for batch normalization.
+
+t: A 4D input Tensor.
+m: A 1D mean Tensor with size matching the last dimension of t.
+  This is the first output from MovingMoments.
+v: A 1D variance Tensor with size matching the last dimension of t.
+  This is the second output from MovingMoments.
+gamma: A 1D gamma Tensor with size matching the last dimension of t.
+  If "scale_after_normalization" is true, this Tensor will be multiplied
+  with the normalized Tensor.
+backprop: 4D backprop Tensor.
+variance_epsilon: A small float number to avoid dividing by 0.
+scale_after_normalization: A bool indicating whether the resulting tensor
+  needs to be multiplied with gamma.
+
+dx: 4D backprop tensor for input.
+dm: 1D backprop tensor for mean.
+dv: 1D backprop tensor for variance.
+db: 1D backprop tensor for beta.
+dg: 1D backprop tensor for gamma.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("BiasAdd")
+    .Attr("T: numbertype")
+    .Input("value: T")
+    .Input("bias: T")
+    .Output("output: T")
+    .Doc(R"doc(
+Adds `bias` to `value`.
+
+This is a special case of `tf.add` where `bias` is restricted to be 1-D.
+Broadcasting is supported, so `value` may have any number of dimensions.
+
+value: Any number of dimensions.
+bias: 1-D with size the last dimension of `value`.
+output: Broadcasted sum of `value` and `bias`.
+)doc");
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Conv2D")
+    .Input("input: T")
+    .Input("filter: T")
+    .Output("output: T")
+    .Attr("T: {float, double}")
+    .Attr("strides: list(int)")
+    .Attr("use_cudnn_on_gpu: bool = true")
+    .Attr(GetPaddingAttrString())
+    .Doc(R"doc(
+Computes a 2-D convolution given 4-D `input` and `filter` tensors.
+
+Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
+and a filter / kernel tensor of shape
+`[filter_height, filter_width, in_channels, out_channels]`, this op
+performs the following:
+
+1. Flattens the filter to a 2-D matrix with shape
+   `[filter_height * filter_width * in_channels, output_channels]`.
+2. Extracts image patches from the input tensor to form a *virtual*
+   tensor of shape `[batch, out_height, out_width,
+   filter_height * filter_width * in_channels]`.
+3. For each patch, right-multiplies the filter matrix and the image patch
+   vector.
+
+In detail,
+
+    output[b, i, j, k] =
+        sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
+                        filter[di, dj, q, k]
+
+Must have `strides[0] = strides[3] = 1`. For the most common case of the same
+horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
+
+strides: 1-D of length 4.
The stride of the sliding window for each dimension + of `input`. +padding: The type of padding algorithm to use. +)doc"); + +REGISTER_OP("Conv2DBackpropInput") + .Input("input_sizes: int32") + .Input("filter: T") + .Input("out_backprop: T") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("strides: list(int)") + .Attr("use_cudnn_on_gpu: bool = true") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes the gradients of convolution with respect to the input. + +input_sizes: An integer vector representing the shape of `input`, + where `input` is a 4-D `[batch, height, width, channels]` tensor. +filter: 4-D with shape + `[filter_height, filter_width, in_channels, out_channels]`. +out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. + Gradients w.r.t. the output of the convolution. +strides: The stride of the sliding window for each dimension of the input + of the convolution. +padding: The type of padding algorithm to use. +output: 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient + w.r.t. the input of the convolution. +)doc"); + +// TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a +// more general string attribute ('kernel_impl'?) that can be used to +// select among several possible implementations. +REGISTER_OP("Conv2DBackpropFilter") + .Input("input: T") + .Input("filter_sizes: int32") + .Output("output: T") + .Input("out_backprop: T") + .Attr("T: {float, double}") + .Attr("strides: list(int)") + .Attr("use_cudnn_on_gpu: bool = true") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes the gradients of convolution with respect to the filter. + +input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +filter_sizes: An integer vector representing the tensor shape of `filter`, + where `filter` is a 4-D + `[filter_height, filter_width, in_channels, out_channels]` tensor. +out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. + Gradients w.r.t. the output of the convolution. +strides: The stride of the sliding window for each dimension of the input + of the convolution. +padding: The type of padding algorithm to use. +output: 4-D with shape + `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. + the `filter` input of the convolution. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("L2Loss") + .Input("t: T") + .Output("output: T") + .Attr("T: numbertype") + .Doc(R"doc( +L2 Loss. + +Computes half the L2 norm of a tensor without the `sqrt`: + + output = sum(t ** 2) / 2 + +t: Typically 2-D, but may have any dimensions. +output: 0-D. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("LRN") + .Input("input: float") + .Output("output: float") + .Attr("depth_radius: int = 5") + .Attr("bias: float = 1.0") + .Attr("alpha: float = 1.0") + .Attr("beta: float = 0.5") + .Doc(R"doc( +Local Response Normalization. + +The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last +dimension), and each vector is normalized independently. Within a given vector, +each component is divided by the weighted, squared sum of inputs within +`depth_radius`. 
In detail, + + sqr_sum[a, b, c, d] = + sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) + output = input / (bias + alpha * sqr_sum ** beta) + +For details, see [Krizhevsky et al., ImageNet classification with deep +convolutional neural networks (NIPS 2012)] +(http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). + +input: 4-D. +depth_radius: 0-D. Half-width of the 1-D normalization window. +bias: An offset (usually positive to avoid dividing by 0). +alpha: A scale factor, usually positive. +beta: An exponent. +)doc"); + +REGISTER_OP("LRNGrad") + .Input("input_grads: float") + .Input("input_image: float") + .Input("output_image: float") + .Output("output: float") + .Attr("depth_radius: int = 5") + .Attr("bias: float = 1.0") + .Attr("alpha: float = 1.0") + .Attr("beta: float = 0.5") + .Doc(R"doc( +Gradients for Local Response Normalization. + +input_grads: 4-D with shape `[batch, height, width, channels]`. +input_image: 4-D with shape `[batch, height, width, channels]`. +output_image: 4-D with shape `[batch, height, width, channels]`. +depth_radius: A depth radius. +bias: An offset (usually > 0 to avoid dividing by 0). +alpha: A scale factor, usually positive. +beta: An exponent. +output: The gradients for LRN. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("MaxPool") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Input("input: float") + .Output("output: float") + .Doc(R"doc( +Performs max pooling on the input. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +input: 4-D input to pool over. +output: The max pooled output tensor. +)doc"); + +REGISTER_OP("MaxPoolGrad") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Input("orig_input: float") + .Input("orig_output: float") + .Input("grad: float") + .Output("output: float") + .Doc(R"doc( +Computes gradients of the maxpooling function. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +orig_input: The original input tensor. +orig_output: The original output tensor. +grad: 4-D. Gradients w.r.t. the output of `max_pool`. +output: Gradients w.r.t. the input to `max_pool`. +)doc"); + +REGISTER_OP("MaxPoolWithArgmax") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr("Targmax: {int32, int64} = DT_INT64") + .Attr(GetPaddingAttrString()) + .Input("input: float") + .Output("output: float") + .Output("argmax: Targmax") + .Doc(R"doc( +Performs max pooling on the input and outputs both max values and indices. + +The indices in `argmax` are flattened, so that a maximum value at position +`[b, y, x, c]` becomes flattened index +`((b * height + y) * width + x) * channels + c`. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. +output: The max pooled output tensor. +argmax: 4-D. The flattened indices of the max values chosen for each output. 
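+
+For example (an illustrative calculation of the formula above): with an input
+of height 4, width 4 and 3 channels, a maximum found at position
+`[b=0, y=2, x=1, c=2]` is reported as flattened index
+`((0 * 4 + 2) * 4 + 1) * 3 + 2 = 29`.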
+)doc"); + +REGISTER_OP("MaxPoolGradWithArgmax") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr("Targmax: {int32, int64}") + .Input("input: float") + .Input("grad: float") + .Input("argmax: Targmax") + .Output("output: float") + .Doc(R"doc( +Computes gradients of the maxpooling function. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +input: The original input. +grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the + output of `max_pool`. +argmax: The indices of the maximum values chosen for each output of `max_pool`. +output: Gradients w.r.t. the input of `max_pool`. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("Relu") + .Input("features: T") + .Output("activations: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes rectified linear: `max(features, 0)`. +)doc"); + +REGISTER_OP("ReluGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes rectified linear gradients for a Relu operation. + +gradients: The backpropagated gradients to the corresponding Relu operation. +features: The features passed as input to the corresponding Relu operation. +backprops: The gradients: `gradients * features * (features > 0)`. +)doc"); + +REGISTER_OP("Relu6") + .Input("features: T") + .Output("activations: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes rectified linear 6: `min(max(features, 0), 6)`. +)doc"); + +REGISTER_OP("Relu6Grad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes rectified linear 6 gradients for a Relu6 operation. + +gradients: The backpropagated gradients to the corresponding Relu6 operation. +features: The features passed as input to the corresponding Relu6 operation. +backprops: The gradients: + `gradients * features * (features > 0) * (features < 6)`. +)doc"); + +REGISTER_OP("Softplus") + .Input("features: T") + .Output("activations: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes softplus: `log(exp(features) + 1)`. +)doc"); + +REGISTER_OP("SoftplusGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes softplus gradients for a softplus operation. + +gradients: The backpropagated gradients to the corresponding softplus operation. +features: The features passed as input to the corresponding softplus operation. +backprops: The gradients: `gradients / (1 + exp(-features))`. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("Softmax") + .Input("logits: T") + .Output("softmax: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Computes softmax activations. + +For each batch `i` and class `j` we have + + softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i])) + +logits: 2-D with shape `[batch_size, num_classes]`. +softmax: Same shape as `logits`. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("SoftmaxCrossEntropyWithLogits") + .Input("features: T") + .Input("labels: T") + .Output("loss: T") + .Output("backprop: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Computes softmax cross entropy cost and gradients to backpropagate. 
+ +Inputs are the logits, not probabilities. + +features: batch_size x num_classes matrix +labels: batch_size x num_classes matrix + The caller must ensure that each batch of labels represents a valid + probability distribution. +loss: Per example loss (batch_size vector). +backprop: backpropagated gradients (batch_size x num_classes matrix). +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("InTopK") + .Attr("k: int") + .Input("predictions: float") + .Input("targets: int32") + .Output("precision: bool") + .Doc(R"doc( +Says whether the targets are in the top K predictions. + +This outputs a batch_size bool array, an entry out[i] is true if the +prediction for the target class is among the top k predictions among +all predictions for example i. Note that the behavior of InTopK differs +from the TopK op in its handling of ties; if multiple classes have the +same prediction value and straddle the top-k boundary, all of those +classes are considered to be in the top k. + +More formally, let + + \\(predictions_i\\) be the predictions for all classes for example i, + \\(targets_i\\) be the target class for example i, + \\(out_i\\) be the output for example i, + +$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ + +predictions: A batch_size x classes tensor +targets: A batch_size vector of class ids +k: Number of top elements to look at for computing precision +precision: Computed Precision at k as a bool Tensor + +)doc"); + +REGISTER_OP("TopK") + .Attr("k: int >= 1") + .Input("input: T") + .Output("values: T") + .Output("indices: int32") + .Attr("T: realnumbertype") + .Doc(R"doc( +Returns the values and indices of the k largest elements for each row. + +\\(values_{i, j}\\) represents the j-th largest element in \\(input_i\\). + +\\(indices_{i, j}\\) gives the column index of the corresponding element, +such that \\(input_{i, indices_{i, j}} = values_{i, j}\\). If two +elements are equal, the lower-index element appears first. + +k: Number of top elements to look for within each row +input: A batch_size x classes tensor +values: A batch_size x k tensor with the k largest elements for each row, + sorted in descending order +indices: A batch_size x k tensor with the index of each value within each row + +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/no_op.cc b/tensorflow/core/ops/no_op.cc new file mode 100644 index 0000000000..52778917cb --- /dev/null +++ b/tensorflow/core/ops/no_op.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("NoOp") + .Doc(R"doc( +Does nothing. Only useful as a placeholder for control edges. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc new file mode 100644 index 0000000000..7fcaa3abf1 --- /dev/null +++ b/tensorflow/core/ops/parsing_ops.cc @@ -0,0 +1,104 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("DecodeRaw") + .Input("bytes: string") + .Output("output: out_type") + .Attr("out_type: {float,double,int32,uint8,int16,int8,int64}") + .Attr("little_endian: bool = true") + .Doc(R"doc( +Reinterpret the bytes of a string as a vector of numbers. + +bytes: All the elements must have the same length. +little_endian: Whether the input bytes are in little-endian order. + Ignored for out_types that are stored in a single byte like uint8. +output: A Tensor with one more dimension than the input bytes. 
The + added dimension will have size equal to the length of the elements + of bytes divided by the number of bytes to represent out_type. +)doc"); + +REGISTER_OP("ParseExample") + .Input("serialized: string") + .Input("names: string") + .Input("sparse_keys: Nsparse * string") + .Input("dense_keys: Ndense * string") + .Input("dense_defaults: Tdense") + .Output("sparse_indices: Nsparse * int64") + .Output("sparse_values: sparse_types") + .Output("sparse_shapes: Nsparse * int64") + .Output("dense_values: Tdense") + .Attr("Nsparse: int >= 0") // Inferred from sparse_keys + .Attr("Ndense: int >= 0") // Inferred from dense_keys + .Attr("sparse_types: list({float,int64,string}) >= 0") + .Attr("Tdense: list({float,int64,string}) >= 0") + .Attr("dense_shapes: list(shape) >= 0") + .Doc(R"doc( +Transforms a vector of brain.Example protos (as strings) into typed tensors. + +serialized: A vector containing a batch of binary serialized Example protos. +names: A vector containing the names of the serialized protos. + May contain, for example, table key (descriptive) names for the + corresponding serialized protos. These are purely useful for debugging + purposes, and the presence of values here has no effect on the output. + May also be an empty vector if no names are available. + If non-empty, this vector must be the same length as "serialized". +dense_keys: A list of Ndense string Tensors (scalars). + The keys expected in the Examples' features associated with dense values. +dense_defaults: A list of Ndense Tensors (some may be empty). + dense_defaults[j] provides default values + when the example's feature_map lacks dense_key[j]. If an empty Tensor is + provided for dense_defaults[j], then the Feature dense_keys[j] is required. + The input type is inferred from dense_defaults[j], even when it's empty. + If dense_defaults[j] is not empty, its shape must match dense_shapes[j]. +dense_shapes: A list of Ndense shapes; the shapes of data in each Feature + given in dense_keys. + The number of elements in the Feature corresponding to dense_key[j] + must always equal dense_shapes[j].NumEntries(). + If dense_shapes[j] == (D0, D1, ..., DN) then the the shape of output + Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): + The dense outputs are just the inputs row-stacked by batch. +sparse_keys: A list of Nsparse string Tensors (scalars). + The keys expected in the Examples' features associated with sparse values. +sparse_types: A list of Nsparse types; the data types of data in each Feature + given in sparse_keys. + Currently the ParseExample supports DT_FLOAT (FloatList), + DT_INT64 (Int64List), and DT_STRING (BytesList). +)doc"); + +REGISTER_OP("DecodeCSV") + .Input("records: string") + .Input("record_defaults: OUT_TYPE") + .Output("output: OUT_TYPE") + .Attr("OUT_TYPE: list({float,int32,int64,string})") + .Attr("field_delim: string = ','") + .Doc(R"doc( +Convert CSV records to tensors. Each column maps to one tensor. + +RFC 4180 format is expected for the CSV records. +(https://tools.ietf.org/html/rfc4180) +Note that we allow leading and trailing spaces with int or float field. + +records: Each string is a record/row in the csv and all records should have + the same format. +record_defaults: One tensor per column of the input record, with either a + scalar default value for that column or empty if the column is required. +field_delim: delimiter to separate fields in a record. +output: Each tensor will have the same shape as records. 
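+
+As a rough illustration of the record/field split only (RFC 4180 quoting and
+per-column type conversion are deliberately omitted here; `SplitRecord` is a
+made-up helper):
+
+```prettyprint
+#include <sstream>
+#include <string>
+#include <vector>
+
+// Split one unquoted record on `field_delim`; a trailing empty field is
+// dropped by this simple loop.
+std::vector<std::string> SplitRecord(const std::string& record,
+                                     char field_delim) {
+  std::vector<std::string> fields;
+  std::string field;
+  std::istringstream in(record);
+  while (std::getline(in, field, field_delim)) {
+    fields.push_back(field);
+  }
+  return fields;
+}
+```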
+)doc"); + +REGISTER_OP("StringToNumber") + .Input("string_tensor: string") + .Output("output: out_type") + .Attr("out_type: {float, int32} = DT_FLOAT") + .Doc(R"doc( +Converts each string in the input Tensor to the specified numeric type. + +(Note that int32 overflow results in an error while float overflow +results in a rounded value.) + +out_type: The numeric type to interpret each string in string_tensor as. +output: A Tensor of the same shape as the input string_tensor. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc new file mode 100644 index 0000000000..4be4354b85 --- /dev/null +++ b/tensorflow/core/ops/random_ops.cc @@ -0,0 +1,108 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("RandomUniform") + .Input("shape: T") + .SetIsStateful() + .Output("output: dtype") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("dtype: {float,double}") + .Attr("T: {int32, int64}") + .Doc(R"doc( +Outputs random values from a uniform distribution. + +The generated values follow a uniform distribution in the range `[0, 1)`. The +lower bound 0 is included in the range, while the upper bound 1 is excluded. + +shape: The shape of the output tensor. +dtype: The type of the output. +seed: If either `seed` or `seed2` are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: A second seed to avoid seed collision. + +output: A tensor of the specified shape filled with uniform random values. +)doc"); + +REGISTER_OP("RandomStandardNormal") + .Input("shape: T") + .SetIsStateful() + .Output("output: dtype") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("dtype: {float,double}") + .Attr("T: {int32, int64}") + .Doc(R"doc( +Outputs random values from a normal distribution. + +The generated values will have mean 0 and standard deviation 1. + +shape: The shape of the output tensor. +dtype: The type of the output. +seed: If either `seed` or `seed2` are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: A second seed to avoid seed collision. + +output: A tensor of the specified shape filled with random normal values. +)doc"); + +REGISTER_OP("TruncatedNormal") + .Input("shape: T") + .SetIsStateful() + .Output("output: dtype") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("dtype: {float,double}") + .Attr("T: {int32, int64}") + .Doc(R"doc( +Outputs random values from a truncated normal distribution. + +The generated values follow a normal distribution with mean 0 and standard +deviation 1, except that values whose magnitude is more than 2 standard +deviations from the mean are dropped and re-picked. + +shape: The shape of the output tensor. +dtype: The type of the output. +seed: If either `seed` or `seed2` are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: A second seed to avoid seed collision. + +output: A tensor of the specified shape filled with random truncated normal + values. +)doc"); + +REGISTER_OP("RandomShuffle") + .Input("value: T") + .SetIsStateful() + .Output("output: T") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("T: type") + .Doc(R"doc( +Randomly shuffles a tensor along its first dimension. + + The tensor is shuffled along dimension 0, such that each `value[j]` is mapped + to one and only one `output[i]`. 
For example, a mapping that might occur for a + 3x2 tensor is: + +```prettyprint +[[1, 2], [[5, 6], + [3, 4], ==> [1, 2], + [5, 6]] [3, 4]] +``` + +value: The tensor to be shuffled. +seed: If either `seed` or `seed2` are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: A second seed to avoid seed collision. + +output: A tensor of same shape and type as `value`, shuffled along its first + dimension. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/sendrecv_ops.cc b/tensorflow/core/ops/sendrecv_ops.cc new file mode 100644 index 0000000000..51158263c1 --- /dev/null +++ b/tensorflow/core/ops/sendrecv_ops.cc @@ -0,0 +1,99 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("_Send") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .Attr("send_device_incarnation: int") + .Attr("recv_device: string") + .Attr("client_terminated: bool = false") + .Doc(R"doc( +Sends the named tensor from send_device to recv_device. + +tensor: The tensor to send. +tensor_name: The name of the tensor to send. +send_device: The name of the device sending the tensor. +send_device_incarnation: The current incarnation of send_device. +recv_device: The name of the device receiving the tensor. +client_terminated: If set to true, this indicates that the node was added + to the graph as a result of a client-side feed or fetch of Tensor data, + in which case the corresponding send or recv is expected to be managed + locally by the caller. +)doc"); + +REGISTER_OP("_Recv") + .Output("tensor: tensor_type") + .Attr("tensor_type: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .Attr("send_device_incarnation: int") + .Attr("recv_device: string") + .Attr("client_terminated: bool = false") + .Doc(R"doc( +Receives the named tensor from send_device on recv_device. + +tensor: The tensor to receive. +tensor_name: The name of the tensor to receive. +send_device: The name of the device sending the tensor. +send_device_incarnation: The current incarnation of send_device. +recv_device: The name of the device receiving the tensor. +client_terminated: If set to true, this indicates that the node was added + to the graph as a result of a client-side feed or fetch of Tensor data, + in which case the corresponding send or recv is expected to be managed + locally by the caller. +)doc"); + +REGISTER_OP("_HostSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .Attr("send_device_incarnation: int") + .Attr("recv_device: string") + .Attr("client_terminated: bool = false") + .Doc(R"doc( +Sends the named tensor from send_device to recv_device. + +_HostSend requires its input on host memory whereas _Send requires its +input on device memory. + +tensor: The tensor to send. +tensor_name: The name of the tensor to send. +send_device: The name of the device sending the tensor. +send_device_incarnation: The current incarnation of send_device. +recv_device: The name of the device receiving the tensor. +client_terminated: If set to true, this indicates that the node was added + to the graph as a result of a client-side feed or fetch of Tensor data, + in which case the corresponding send or recv is expected to be managed + locally by the caller. 
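+
+For illustration only, these attributes can be thought of as forming a unique
+channel key that pairs a send with its recv (hypothetical sketch; the key
+format actually used by the runtime is not specified in this file):
+
+```prettyprint
+#include <string>
+
+// Hypothetical: concatenate the attributes that identify a logical channel.
+std::string ChannelKey(const std::string& send_device, long long incarnation,
+                       const std::string& recv_device,
+                       const std::string& tensor_name) {
+  return send_device + ";" + std::to_string(incarnation) + ";" + recv_device +
+         ";" + tensor_name;
+}
+```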
+)doc"); + +REGISTER_OP("_HostRecv") + .Output("tensor: tensor_type") + .Attr("tensor_type: type") + .Attr("tensor_name: string") + .Attr("send_device: string") + .Attr("send_device_incarnation: int") + .Attr("recv_device: string") + .Attr("client_terminated: bool = false") + .Doc(R"doc( +Receives the named tensor from send_device on recv_device. + +_HostRecv requires its input on host memory whereas _Recv requires its +input on device memory. + +tensor: The tensor to receive. +tensor_name: The name of the tensor to receive. +send_device: The name of the device sending the tensor. +send_device_incarnation: The current incarnation of send_device. +recv_device: The name of the device receiving the tensor. +client_terminated: If set to true, this indicates that the node was added + to the graph as a result of a client-side feed or fetch of Tensor data, + in which case the corresponding send or recv is expected to be managed + locally by the caller. +)doc"); + +} // end namespace tensorflow diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc new file mode 100644 index 0000000000..51262373d5 --- /dev/null +++ b/tensorflow/core/ops/sparse_ops.cc @@ -0,0 +1,134 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("SparseToDense") + .Input("sparse_indices: Tindices") + .Input("output_shape: Tindices") + .Input("sparse_values: T") + .Input("default_value: T") + .Output("dense: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .Doc(R"doc( +Converts a sparse representation into a dense tensor. + +Builds an array `dense` with shape `output_shape` such that + +```prettyprint +# If sparse_indices is scalar +dense[i] = (i == sparse_indices ? sparse_values : default_value) + +# If sparse_indices is a vector, then for each i +dense[sparse_indices[i]] = sparse_values[i] + +# If sparse_indices is an n by d matrix, then for each i in [0, n) +dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] +``` + +All other values in `dense` are set to `default_value`. If `sparse_values` is a +scalar, all sparse indices are set to this single value. + +sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete + index where `sparse_values[i]` will be placed. +output_shape: 1-D. Shape of the dense output tensor. +sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, + or a scalar value to be used for all sparse indices. +default_value: Scalar value to set for indices not specified in + `sparse_indices`. +dense: Dense output tensor of shape `output_shape`. +)doc"); + +REGISTER_OP("SparseConcat") + .Input("indices: N * int64") + .Input("values: N * T") + .Input("shapes: N * int64") + .Output("output_indices: int64") + .Output("output_values: T") + .Output("output_shape: int64") + .Attr("concat_dim: int >= 0") + .Attr("N: int >= 2") + .Attr("T: type") + .Doc(R"doc( +Concatenates a list of `SparseTensor` along the specified dimension. + +Concatenation is with respect to the dense versions of these sparse tensors. +It is assumed that each input is a `SparseTensor` whose elements are ordered +along increasing dimension number. + +All inputs' shapes must match, except for the concat dimension. The +`indices`, `values`, and `shapes` lists must have the same length. + +The output shape is identical to the inputs', except along the concat +dimension, where it is the sum of the inputs' sizes along that dimension. 
+ +The output elements will be resorted to preserve the sort order along +increasing dimension number. + +This op runs in `O(M log M)` time, where `M` is the total number of non-empty +values across all inputs. This is due to the need for an internal sort in +order to concatenate efficiently across an arbitrary dimension. + +For example, if `concat_dim = 1` and the inputs are + + sp_inputs[0]: shape = [2, 3] + [0, 2]: "a" + [1, 0]: "b" + [1, 1]: "c" + + sp_inputs[1]: shape = [2, 4] + [0, 1]: "d" + [0, 2]: "e" + +then the output will be + + shape = [2, 7] + [0, 2]: "a" + [0, 4]: "d" + [0, 5]: "e" + [1, 0]: "b" + [1, 1]: "c" + +Graphically this is equivalent to doing + + [ a] concat [ d e ] = [ a d e ] + [b c ] [ ] [b c ] + +indices: 2-D. Indices of each input `SparseTensor`. +values: 1-D. Non-empty values of each `SparseTensor`. +shapes: 1-D. Shapes of each `SparseTensor`. +output_indices: 2-D. Indices of the concatenated `SparseTensor`. +output_values: 1-D. Non-empty values of the concatenated `SparseTensor`. +output_shape: 1-D. Shape of the concatenated `SparseTensor`. +concat_dim: Dimension to concatenate along. +)doc"); + +REGISTER_OP("SparseReorder") + .Input("input_indices: int64") + .Input("input_values: T") + .Input("input_shape: int64") + .Output("output_indices: int64") + .Output("output_values: T") + .Attr("T: type") + .Doc(R"doc( +Reorders a SparseTensor into the canonical, row-major ordering. + +Note that by convention, all sparse ops preserve the canonical ordering along +increasing dimension number. The only time ordering can be violated is during +manual manipulation of the indices and values vectors to add entries. + +Reordering does not affect the shape of the SparseTensor. + +If the tensor has rank `R` and `N` non-empty values, `input_indices` has +shape `[N, R]`, input_values has length `N`, and input_shape has length `R`. + +input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a + SparseTensor, possibly not in canonical ordering. +input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +input_shape: 1-D. Shape of the input SparseTensor. +output_indices: 2-D. `N x R` matrix with the same indices as input_indices, but + in canonical row-major ordering. +output_values: 1-D. `N` non-empty values corresponding to `output_indices`. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc new file mode 100644 index 0000000000..da9fd4ad08 --- /dev/null +++ b/tensorflow/core/ops/state_ops.cc @@ -0,0 +1,290 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("Variable") + .Output("ref: Ref(dtype)") + .Attr("shape: shape") + .Attr("dtype: type") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +Holds state in the form of a tensor that persists across steps. + +Outputs a ref to the tensor state so it may be read or modified. +TODO(zhifengc/mrry): Adds a pointer to a more detail document +about sharing states in tensorflow. + +ref: A reference to the variable tensor. +shape: The shape of the variable tensor. +dtype: The type of elements in the variable tensor. +container: If non-empty, this variable is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this variable is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. 
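+
+As a rough sketch of that naming rule only (the runtime's actual resource
+bookkeeping is more involved; `EffectiveSharedName` is a made-up helper):
+
+```prettyprint
+#include <string>
+
+// Resolve the name under which the variable's state is shared.
+std::string EffectiveSharedName(const std::string& shared_name,
+                                const std::string& node_name) {
+  return shared_name.empty() ? node_name : shared_name;
+}
+```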
+)doc"); + +REGISTER_OP("TemporaryVariable") + .Output("ref: Ref(dtype)") + .Attr("shape: shape") + .Attr("dtype: type") + .Attr("var_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +Returns a tensor that may be mutated, but only persists within a single step. + +This is an experimental op for internal use only and it is possible to use this +op in unsafe ways. DO NOT USE unless you fully understand the risks. + +It is the caller's responsibility to ensure that 'ref' is eventually passed to a +matching 'DestroyTemporaryVariable' op after all other uses have completed. + +Outputs a ref to the tensor state so it may be read or modified. + + E.g. + var = state_ops._temporary_variable([1, 2], types.float_) + var_name = var.op.name + var = state_ops.assign(var, [[4.0, 5.0]]) + var = state_ops.assign_add(var, [[6.0, 7.0]]) + final = state_ops._destroy_temporary_variable(var, var_name=var_name) + +ref: A reference to the variable tensor. +shape: The shape of the variable tensor. +dtype: The type of elements in the variable tensor. +var_name: Overrides the name used for the temporary variable resource. Default +value is the name of the 'TemporaryVariable' op (which is guaranteed unique). +)doc"); + +REGISTER_OP("DestroyTemporaryVariable") + .Input("ref: Ref(T)") + .Output("value: T") + .Attr("T: type") + .Attr("var_name: string") + .Doc(R"doc( +Destroys the temporary variable and returns its final value. + +Sets output to the value of the Tensor pointed to by 'ref', then destroys +the temporary variable called 'var_name'. +All other uses of 'ref' *must* have executed before this op. +This is typically achieved by chaining the ref through each assign op, or by +using control dependencies. + +Outputs the final value of the tensor pointed to by 'ref'. + +ref: A reference to the temporary variable tensor. +var_name: Name of the temporary variable, usually the name of the matching +'TemporaryVariable' op. +)doc"); + +REGISTER_OP("Assign") + .Input("ref: Ref(T)") + .Input("value: T") + .Output("output_ref: Ref(T)") + .Attr("T: type") + .Attr("validate_shape: bool = true") + .Attr("use_locking: bool = true") + .SetAllowsUninitializedInput() + .Doc(R"doc( +Update 'ref' by assigning 'value' to it. + +This operation outputs "ref" after the assignment is done. +This makes it easier to chain operations that need to use the reset value. + +ref: Should be from a `Variable` node. May be uninitialized. +value: The value to be assigned to the variable. +validate_shape: If true, the operation will validate that the shape + of 'value' matches the shape of the Tensor being assigned to. If false, + 'ref' will take on the shape of 'value'. +use_locking: If True, the assignment will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +output_ref:= Same as "ref". Returned as a convenience for operations that want + to use the new value after the variable has been reset. +)doc"); + +REGISTER_OP("AssignAdd") + .Input("ref: Ref(T)") + .Input("value: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update 'ref' by adding 'value' to it. + +This operation outputs "ref" after the update is done. +This makes it easier to chain operations that need to use the reset value. + +ref: Should be from a `Variable` node. +value: The value to be added to the variable. +use_locking: If True, the addition will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. 
+output_ref:= Same as "ref".  Returned as a convenience for operations that want
+  to use the new value after the variable has been updated.
+)doc");
+
+REGISTER_OP("AssignSub")
+    .Input("ref: Ref(T)")
+    .Input("value: T")
+    .Output("output_ref: Ref(T)")
+    .Attr("T: numbertype")
+    .Attr("use_locking: bool = false")
+    .Doc(R"doc(
+Update 'ref' by subtracting 'value' from it.
+
+This operation outputs "ref" after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node.
+value: The value to be subtracted from the variable.
+use_locking: If True, the subtraction will be protected by a lock;
+  otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref".  Returned as a convenience for operations that want
+  to use the new value after the variable has been updated.
+)doc");
+
+REGISTER_OP("ScatterUpdate")
+    .Input("ref: Ref(T)")
+    .Input("indices: Tindices")
+    .Input("updates: T")
+    .Output("output_ref: Ref(T)")
+    .Attr("T: type")
+    .Attr("Tindices: {int32, int64}")
+    .Attr("use_locking: bool = true")
+    .Doc(R"doc(
+Applies sparse updates to a variable reference.
+
+This operation computes
+
+    # Scalar indices
+    ref[indices, ...] = updates[...]
+
+    # Vector indices (for each i)
+    ref[indices[i], ...] = updates[i, ...]
+
+    # High rank indices (for each i, ..., j)
+    ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+If `indices` contains duplicate entries, lexicographically later entries
+override earlier entries.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
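+For the vector-indices case, a minimal sketch of these semantics (illustrative
+only; every row of `ref` and `updates` is assumed to have the same width):
+
+```prettyprint
+#include <cstdint>
+#include <vector>
+
+// ref[indices[i], :] = updates[i, :]. With duplicate indices the last
+// occurrence wins, since later iterations overwrite earlier ones.
+void ScatterUpdateRows(std::vector<std::vector<float>>* ref,
+                       const std::vector<int64_t>& indices,
+                       const std::vector<std::vector<float>>& updates) {
+  for (size_t i = 0; i < indices.size(); ++i) {
+    (*ref)[indices[i]] = updates[i];
+  }
+}
+```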
+ +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to store in `ref`. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the assignment will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ScatterAdd") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Adds sparse updates to a variable reference. + +This operation computes + + # Scalar indices + ref[indices, ...] += updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] += updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]`. + +
+ +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to add to `ref`. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the addition will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ScatterSub") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Subtracts sparse updates to a variable reference. + + # Scalar indices + ref[indices, ...] -= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] -= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their (negated) contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]`. + +
+ +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to subtract from `ref`. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the subtraction will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("CountUpTo") + .Input("ref: Ref(T)") + .Output("output: T") + .Attr("limit: int") + .Attr("T: {int32, int64}") + .Doc(R"doc( +Increments 'ref' until it reaches 'limit'. + +This operation outputs "ref" after the update is done. This makes it +easier to chain operations that need to use the updated value. + +ref: Should be from a scalar `Variable` node. +limit: If incrementing ref would bring it above limit, instead generates an + 'OutOfRange' error. +output: A copy of the input before increment. If nothing else modifies the + input, the values produced will all be distinct. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc new file mode 100644 index 0000000000..57b471074c --- /dev/null +++ b/tensorflow/core/ops/string_ops.cc @@ -0,0 +1,21 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("StringToHashBucket") + .Input("string_tensor: string") + .Output("output: int64") + .Attr("num_buckets: int >= 1") + .Doc(R"doc( +Converts each string in the input Tensor to its hash mod by a number of buckets. + +The hash function is deterministic on the content of the string within the +process. + +Note that the hash function may change from time to time. + +num_buckets: The number of buckets. +output: A Tensor of the same shape as the input string_tensor. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc new file mode 100644 index 0000000000..5f46c871b6 --- /dev/null +++ b/tensorflow/core/ops/summary_ops.cc @@ -0,0 +1,115 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as +// inputs or outputs in various ways. + +REGISTER_OP("ScalarSummary") + .Input("tags: string") + .Input("values: T") + .Output("summary: string") + .Attr("T: {float, double}") + .Doc(R"doc( +Outputs a `Summary` protocol buffer with scalar values. + +The input `tags` and `values` must have the same shape. The generated summary +has a summary value for each tag-value pair in `tags` and `values`. + +tags: 1-D. Tags for the summary. +values: 1-D, same size as `tags. Values for the summary. +summary: Scalar. Serialized `Summary` protocol buffer. +)doc"); + +REGISTER_OP("HistogramSummary") + .Input("tag: string") + .Input("values: float") + .Output("summary: string") + .Doc(R"doc( +Outputs a `Summary` protocol buffer with a histogram. + +The generated +[`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto) +has one summary value containing a histogram for `values`. + +This op reports an `OutOfRange` error if any value is not finite. + +tag: Scalar. Tag to use for the `Summary.Value`. +values: Any shape. Values to use to build the histogram. +summary: Scalar. Serialized `Summary` protocol buffer. 
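+
+As a loose sketch of the bucket-counting step only (illustrative; the actual
+bucket boundaries used for the serialized histogram are not specified here):
+
+```prettyprint
+#include <cstddef>
+#include <vector>
+
+// Count values into half-open buckets bounded above by ascending `limits`;
+// values >= the last limit are ignored in this sketch.
+std::vector<int> CountBuckets(const std::vector<float>& values,
+                              const std::vector<float>& limits) {
+  std::vector<int> counts(limits.size(), 0);
+  for (float v : values) {
+    for (size_t i = 0; i < limits.size(); ++i) {
+      if (v < limits[i]) {
+        ++counts[i];
+        break;
+      }
+    }
+  }
+  return counts;
+}
+```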
+)doc"); + +REGISTER_OP("ImageSummary") + .Input("tag: string") + .Input("tensor: float") + .Output("summary: string") + .Attr("max_images: int >= 1 = 3") + .Attr( + "bad_color: tensor = { dtype: DT_UINT8 " + "tensor_shape: { dim { size: 4 } } " + "int_val: 255 int_val: 0 int_val: 0 int_val: 255 }") + .Doc(R"doc( +Outputs a `Summary` protocol buffer with images. + +The summary has up to `max_images` summary values containing images. The +images are built from `tensor` which must be 4-D with shape `[batch_size, +height, width, channels]` and where `channels` can be: + +* 1: `tensor` is interpreted as Grayscale. +* 3: `tensor` is interpreted as RGB. +* 4: `tensor` is interpreted as RGBA. + +The images have the same number of channels as the input tensor. Their values +are normalized, one image at a time, to fit in the range `[0, 255]`. The +op uses two different normalization algorithms: + +* If the input values are all positive, they are rescaled so the largest one + is 255. + +* If any input value is negative, the values are shifted so input value 0.0 + is at 127. They are then rescaled so that either the smallest value is 0, + or the largest one is 255. + +The `tag` argument is a scalar `Tensor` of type `string`. It is used to +build the `tag` of the summary values: + +* If `max_images` is 1, the summary value tag is '*tag*/image'. +* If `max_images` is greater than 1, the summary value tags are + generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. + +The `bad_color` argument is the color to use in the generated images for +non-finite input values. It is a `unit8` 1-D tensor of length `channels`. +Each element must be in the range `[0, 255]` (It represents the value of a +pixel in the output image). Non-finite values in the input tensor are +replaced by this tensor in the output image. The default value is the color +red. + +tag: Scalar. Used to build the `tag` attribute of the summary values. +tensor: 4-D of shape `[batch_size, height, width, channels]` where + `channels` is 1, 3, or 4. +max_images: Max number of batch elements to generate images for. +bad_color: Color to use for pixels with non-finite values. +summary: Scalar. Serialized `Summary` protocol buffer. +)doc"); + +REGISTER_OP("MergeSummary") + .Input("inputs: N * string") + .Output("summary: string") + .Attr("N : int >= 1") + .Doc(R"doc( +Merges summaries. + +This op creates a +[`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto) +protocol buffer that contains the union of all the values in the input +summaries. + +When the Op is run, it reports an `InvalidArgument` error if multiple values +in the summaries to merge use the same tag. + +inputs: Can be of any shape. Each must contain serialized `Summary` protocol + buffers. +summary: Scalar. Serialized `Summary` protocol buffer. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc new file mode 100644 index 0000000000..e7b4e92fd5 --- /dev/null +++ b/tensorflow/core/ops/training_ops.cc @@ -0,0 +1,199 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("ApplyGradientDescent") + .Input("var: Ref(T)") + .Input("alpha: T") + .Input("delta: T") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update '*var' by subtracting 'alpha' * 'delta' from it. + +var: Should be from a Variable(). +alpha: Scaling factor. Must be a scalar. +delta: The change. 
+out: Same as "var". +use_locking: If True, the subtraction will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ApplyAdagrad") + .Input("var: Ref(T)") + .Input("accum: Ref(T)") + .Input("lr: T") + .Input("grad: T") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update '*var' according to the adagrad scheme. + +accum += grad * grad +var -= lr * grad * (1 / sqrt(accum)) + +var: Should be from a Variable(). +accum: Should be from a Variable(). +lr: Scaling factor. Must be a scalar. +grad: The gradient. +out: Same as "var". +use_locking: If True, updating of the var and accum tensors will be protected by +a lock; otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("SparseApplyAdagrad") + .Input("var: Ref(T)") + .Input("accum: Ref(T)") + .Input("lr: T") + .Input("grad: T") + .Input("indices: Tindices") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update relevant entries in '*var' and '*accum' according to the adagrad scheme. + +That is for rows we have grad for, we update var and accum as follows: +accum += grad * grad +var -= lr * grad * (1 / sqrt(accum)) + +var: Should be from a Variable(). +accum: Should be from a Variable(). +lr: Learning rate. Must be a scalar. +grad: The gradient. +indices: A vector of indices into the first dimension of var and accum. +out: Same as "var". +use_locking: If True, updating of the var and accum tensors will be protected by +a lock; otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ApplyMomentum") + .Input("var: Ref(T)") + .Input("accum: Ref(T)") + .Input("lr: T") + .Input("grad: T") + .Input("momentum: T") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update '*var' according to the momentum scheme. + +accum = accum * momentum + grad +var -= lr * accum + +var: Should be from a Variable(). +accum: Should be from a Variable(). +lr: Scaling factor. Must be a scalar. +grad: The gradient. +momentum: Momentum. Must be a scalar. +out: Same as "var". +use_locking: If True, updating of the var and accum tensors will be protected by +a lock; otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("SparseApplyMomentum") + .Input("var: Ref(T)") + .Input("accum: Ref(T)") + .Input("lr: T") + .Input("grad: T") + .Input("indices: Tindices") + .Input("momentum: T") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update relevant entries in '*var' and '*accum' according to the momentum scheme. + +That is for rows we have grad for, we update var and accum as follows: + +accum = accum * momentum + grad +var -= lr * accum + +var: Should be from a Variable(). +accum: Should be from a Variable(). +lr: Learning rate. Must be a scalar. +grad: The gradient. +indices: A vector of indices into the first dimension of var and accum. +momentum: Momentum. Must be a scalar. +out: Same as "var". +use_locking: If True, updating of the var and accum tensors will be protected by +a lock; otherwise the behavior is undefined, but may exhibit less contention. 
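+
+A dense-row sketch of that update (illustrative only; `var`, `accum`, and the
+rows of `grad` are assumed to have matching widths):
+
+```prettyprint
+#include <cstdint>
+#include <vector>
+
+// For each i: accum[row] = accum[row] * momentum + grad[i];
+//             var[row]  -= lr * accum[row]   (element-wise, row = indices[i]).
+void SparseMomentumUpdate(std::vector<std::vector<float>>* var,
+                          std::vector<std::vector<float>>* accum, float lr,
+                          float momentum,
+                          const std::vector<std::vector<float>>& grad,
+                          const std::vector<int64_t>& indices) {
+  for (size_t i = 0; i < indices.size(); ++i) {
+    std::vector<float>& a = (*accum)[indices[i]];
+    std::vector<float>& v = (*var)[indices[i]];
+    for (size_t j = 0; j < a.size(); ++j) {
+      a[j] = a[j] * momentum + grad[i][j];
+      v[j] -= lr * a[j];
+    }
+  }
+}
+```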
+)doc"); + +REGISTER_OP("ApplyAdam") + .Input("var: Ref(T)") + .Input("m: Ref(T)") + .Input("v: Ref(T)") + .Input("beta1_power: T") + .Input("beta2_power: T") + .Input("lr: T") + .Input("beta1: T") + .Input("beta2: T") + .Input("epsilon: T") + .Input("grad: T") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update '*var' according to the Adam algorithm. + +lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) +m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t +v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t +variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) + +var: Should be from a Variable(). +m: Should be from a Variable(). +v: Should be from a Variable(). +beta1_power: Must be a scalar. +beta2_power: Must be a scalar. +lr: Scaling factor. Must be a scalar. +beta1: Momentum factor. Must be a scalar. +beta2: Momentum factor. Must be a scalar. +epsilon: Ridge term. Must be a scalar. +grad: The gradient. +out: Same as "var". +use_locking: If True, updating of the var, m, and v tensors will be protected by +a lock; otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ApplyRMSProp") + .Input("var: Ref(T)") + .Input("ms: Ref(T)") + .Input("mom: Ref(T)") + .Input("lr: T") + .Input("rho: T") + .Input("momentum: T") + .Input("epsilon: T") + .Input("grad: T") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update '*var' according to the RMSProp algorithm. + +mean_square = decay * mean_square + (1-decay) * gradient ** 2 +Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + +ms <- rho * ms_{t-1} + (1-rho) * grad * grad +mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +var <- var - mom + +var: Should be from a Variable(). +ms: Should be from a Variable(). +mom: Should be from a Variable(). +lr: Scaling factor. Must be a scalar. +epsilon: Ridge term. Must be a scalar. +rho: Decay rate. Must be a scalar. +grad: The gradient. +out: Same as "var". +use_locking: If True, updating of the var, m, and v tensors will be protected by +a lock; otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl new file mode 100644 index 0000000000..7cf6c274be --- /dev/null +++ b/tensorflow/core/platform/default/build_config.bzl @@ -0,0 +1,65 @@ +# Platform-specific build configurations. + +load("/google/protobuf/protobuf", "cc_proto_library") +load("/google/protobuf/protobuf", "py_proto_library") + +# Appends a suffix to a list of deps. +def tf_deps(deps, suffix): + tf_deps = [] + + # If the package name is in shorthand form (ie: does not contain a ':'), + # expand it to the full name. 
+ for dep in deps: + tf_dep = dep + + if not ":" in dep: + dep_pieces = dep.split("/") + tf_dep += ":" + dep_pieces[len(dep_pieces) - 1] + + tf_deps += [tf_dep + suffix] + + return tf_deps + +def tf_proto_library(name, srcs = [], has_services = False, + deps = [], visibility = [], testonly = 0, + cc_api_version = 2, go_api_version = 2, + java_api_version = 2, + py_api_version = 2): + native.filegroup(name=name + "_proto_srcs", + srcs=srcs + tf_deps(deps, "_proto_srcs"), + testonly=testonly,) + + cc_proto_library(name=name + "_cc", + srcs=srcs + tf_deps(deps, "_proto_srcs"), + deps=deps, + cc_libs = ["//google/protobuf:protobuf"], + testonly=testonly, + visibility=visibility,) + + py_proto_library(name=name + "_py", + srcs=srcs + tf_deps(deps, "_proto_srcs"), + deps=deps, + py_libs = ["//google/protobuf:protobuf_python"], + testonly=testonly, + visibility=visibility,) + +def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0): + py_proto_library(name = name + "_py", + srcs = srcs, + deps = deps, + visibility = visibility, + testonly = testonly) + +def tf_additional_lib_srcs(): + return [ + "platform/default/*.h", + "platform/default/*.cc", + "platform/posix/*.h", + "platform/posix/*.cc", + ] + +def tf_additional_test_srcs(): + return ["platform/default/test_benchmark.cc"] + +def tf_kernel_tests_linkstatic(): + return 0 diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD new file mode 100644 index 0000000000..44dbc47ad1 --- /dev/null +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -0,0 +1,85 @@ +# Description: +# Platform-specific build configurations. + +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("/tensorflow/tensorflow", "tf_copts") +load("/tensorflow/tensorflow", "tf_cuda_library") + +cc_library( + name = "gtest", + testonly = 1, + copts = tf_copts(), + deps = [ + "//external:gtest", + ], +) + +cc_library( + name = "tensorflow_platform_specific", + copts = tf_copts(), + linkstatic = 1, + deps = [], +) + +tf_cuda_library( + name = "stream_executor", + deps = [ + "//tensorflow/stream_executor", + ], +) + +cc_library( + name = "platformlib", + copts = tf_copts(), + deps = [ + "@jpeg_archive//:jpeg", + "@png_archive//:png", + "@re2//:re2", + "//tensorflow/core:protos_cc", + ], +) + +cc_library( + name = "protos_cc", + copts = tf_copts(), + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "test_main", + testonly = 1, + linkstatic = 1, + deps = [], +) + +cc_library( + name = "cuda_runtime_extra", + linkstatic = 1, + deps = [], +) + +filegroup( + name = "android_proto_lib_portable_proto", + srcs = [], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cuda", + data = [ + "//third_party/gpus/cuda:lib64/libcudart.so.7.0", + ], + linkopts = [ + "-Wl,-rpath,third_party/gpus/cuda/lib64", + ], + deps = [ + "//third_party/gpus/cuda:cudart", + ], +) diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl new file mode 100644 index 0000000000..439bf97a2c --- /dev/null +++ b/tensorflow/core/platform/default/build_config_root.bzl @@ -0,0 +1,6 @@ +# Lower-level functionality for build config. +# The functions in this file might be referred by tensorflow.bzl. They have to +# be separate to avoid cyclic references. 
+ +def tf_cuda_tests_tags(): + return ["local"] diff --git a/tensorflow/core/platform/default/dynamic_annotations.h b/tensorflow/core/platform/default/dynamic_annotations.h new file mode 100644 index 0000000000..1705fb9955 --- /dev/null +++ b/tensorflow/core/platform/default/dynamic_annotations.h @@ -0,0 +1,9 @@ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_ + +// Do nothing for this platform +#define TF_ANNOTATE_MEMORY_IS_INITIALIZED(ptr, bytes) \ + do { \ + } while (0) + +#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_ diff --git a/tensorflow/core/platform/default/integral_types.h b/tensorflow/core/platform/default/integral_types.h new file mode 100644 index 0000000000..04aae172da --- /dev/null +++ b/tensorflow/core/platform/default/integral_types.h @@ -0,0 +1,18 @@ +#ifndef TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_ +#define TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_ + +namespace tensorflow { + +typedef signed char int8; +typedef short int16; +typedef int int32; +typedef long long int64; + +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long long uint64; + +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_ diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc new file mode 100644 index 0000000000..8a16a537b0 --- /dev/null +++ b/tensorflow/core/platform/default/logging.cc @@ -0,0 +1,125 @@ +#include "tensorflow/core/platform/default/logging.h" + +#if defined(PLATFORM_POSIX_ANDROID) +#include +#include +#endif + +#include + +namespace tensorflow { +namespace internal { + +LogMessage::LogMessage(const char* fname, int line, int severity) + : fname_(fname), line_(line), severity_(severity) {} + +#if defined(PLATFORM_POSIX_ANDROID) +void LogMessage::GenerateLogMessage() { + int android_log_level; + switch (severity_) { + case INFO: + android_log_level = ANDROID_LOG_INFO; + break; + case WARNING: + android_log_level = ANDROID_LOG_WARN; + break; + case ERROR: + android_log_level = ANDROID_LOG_ERROR; + break; + case FATAL: + android_log_level = ANDROID_LOG_FATAL; + break; + default: + if (severity_ < INFO) { + android_log_level = ANDROID_LOG_VERBOSE; + } else { + android_log_level = ANDROID_LOG_ERROR; + } + break; + } + + std::stringstream ss; + ss << fname_ << ":" << line_ << " " << str(); + __android_log_write(android_log_level, "native", ss.str().c_str()); + + // Android logging at level FATAL does not terminate execution, so abort() + // is still required to stop the program. + if (severity_ == FATAL) { + abort(); + } +} + +#else + +void LogMessage::GenerateLogMessage() { + // TODO(jeff,sanjay): For open source version, replace this with something + // that logs through the env or something and fill in appropriate time info. + fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], fname_, line_, + str().c_str()); +} +#endif + +LogMessage::~LogMessage() { GenerateLogMessage(); } + +LogMessageFatal::LogMessageFatal(const char* file, int line) + : LogMessage(file, line, FATAL) {} +LogMessageFatal::~LogMessageFatal() { + // abort() ensures we don't return (we promised we would not via + // ATTRIBUTE_NORETURN). 
+ GenerateLogMessage(); + abort(); +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "char value " << (short)v; + } +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const signed char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "signed char value " << (short)v; + } +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const unsigned char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "unsigned char value " << (unsigned short)v; + } +} + +#if LANG_CXX11 +template <> +void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& p) { + (*os) << "nullptr"; +} +#endif + +CheckOpMessageBuilder::CheckOpMessageBuilder(const char* exprtext) + : stream_(new std::ostringstream) { + *stream_ << "Check failed: " << exprtext << " ("; +} + +CheckOpMessageBuilder::~CheckOpMessageBuilder() { delete stream_; } + +std::ostream* CheckOpMessageBuilder::ForVar2() { + *stream_ << " vs. "; + return stream_; +} + +string* CheckOpMessageBuilder::NewString() { + *stream_ << ")"; + return new string(stream_->str()); +} + +} // namespace internal +} // namespace tensorflow diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h new file mode 100644 index 0000000000..034178751e --- /dev/null +++ b/tensorflow/core/platform/default/logging.h @@ -0,0 +1,258 @@ +#ifndef TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_ +#define TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_ + +#include +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +const int INFO = 0; // base_logging::INFO; +const int WARNING = 1; // base_logging::WARNING; +const int ERROR = 2; // base_logging::ERROR; +const int FATAL = 3; // base_logging::FATAL; +const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES; + +namespace internal { + +class LogMessage : public std::basic_ostringstream { + public: + LogMessage(const char* fname, int line, int severity); + ~LogMessage(); + + protected: + void GenerateLogMessage(); + + private: + const char* fname_; + int line_; + int severity_; +}; + +// LogMessageFatal ensures the process will exit in failure after +// logging this message. +class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) TF_ATTRIBUTE_COLD; + ~LogMessageFatal() TF_ATTRIBUTE_NORETURN; +}; + +#define _TF_LOG_INFO \ + ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO) +#define _TF_LOG_WARNING \ + ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::WARNING) +#define _TF_LOG_ERROR \ + ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::ERROR) +#define _TF_LOG_FATAL \ + ::tensorflow::internal::LogMessageFatal(__FILE__, __LINE__) + +#define LOG(severity) _TF_LOG_##severity + +// TODO(jeff): Define a proper implementation of VLOG_IS_ON +#define VLOG_IS_ON(lvl) ((lvl) <= 0) + +#define VLOG(lvl) \ + if (VLOG_IS_ON(lvl)) \ + ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO) + +// CHECK dies with a fatal error if condition is not true. It is *not* +// controlled by NDEBUG, so the check will be executed regardless of +// compilation mode. 
Therefore, it is safe to do things like: +// CHECK(fp->Write(x) == 4) +#define CHECK(condition) \ + if (TF_PREDICT_FALSE(!(condition))) \ + LOG(FATAL) << "Check failed: " #condition " " + +// Function is overloaded for integral types to allow static const +// integrals declared in classes and not defined to be used as arguments to +// CHECK* macros. It's not encouraged though. +template +inline const T& GetReferenceableValue(const T& t) { + return t; +} +inline char GetReferenceableValue(char t) { return t; } +inline unsigned char GetReferenceableValue(unsigned char t) { return t; } +inline signed char GetReferenceableValue(signed char t) { return t; } +inline short GetReferenceableValue(short t) { return t; } +inline unsigned short GetReferenceableValue(unsigned short t) { return t; } +inline int GetReferenceableValue(int t) { return t; } +inline unsigned int GetReferenceableValue(unsigned int t) { return t; } +inline long GetReferenceableValue(long t) { return t; } +inline unsigned long GetReferenceableValue(unsigned long t) { return t; } +inline long long GetReferenceableValue(long long t) { return t; } +inline unsigned long long GetReferenceableValue(unsigned long long t) { + return t; +} + +// This formats a value for a failing CHECK_XX statement. Ordinarily, +// it uses the definition for operator<<, with a few special cases below. +template +inline void MakeCheckOpValueString(std::ostream* os, const T& v) { + (*os) << v; +} + +// Overrides for char types provide readable values for unprintable +// characters. +template <> +void MakeCheckOpValueString(std::ostream* os, const char& v); +template <> +void MakeCheckOpValueString(std::ostream* os, const signed char& v); +template <> +void MakeCheckOpValueString(std::ostream* os, const unsigned char& v); + +#if LANG_CXX11 +// We need an explicit specialization for std::nullptr_t. +template <> +void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& p); +#endif + +// A container for a string pointer which can be evaluated to a bool - +// true iff the pointer is non-NULL. +struct CheckOpString { + CheckOpString(string* str) : str_(str) {} + // No destructor: if str_ is non-NULL, we're about to LOG(FATAL), + // so there's no point in cleaning up str_. + operator bool() const { return TF_PREDICT_FALSE(str_ != NULL); } + string* str_; +}; + +// Build the error message string. Specify no inlining for code size. +template +string* MakeCheckOpString(const T1& v1, const T2& v2, + const char* exprtext) TF_ATTRIBUTE_NOINLINE; + +// A helper class for formatting "expr (V1 vs. V2)" in a CHECK_XX +// statement. See MakeCheckOpString for sample usage. Other +// approaches were considered: use of a template method (e.g., +// base::BuildCheckOpString(exprtext, base::Print, &v1, +// base::Print, &v2), however this approach has complications +// related to volatile arguments and function-pointer arguments). +class CheckOpMessageBuilder { + public: + // Inserts "exprtext" and " (" to the stream. + explicit CheckOpMessageBuilder(const char* exprtext); + // Deletes "stream_". + ~CheckOpMessageBuilder(); + // For inserting the first variable. + std::ostream* ForVar1() { return stream_; } + // For inserting the second variable (adds an intermediate " vs. "). + std::ostream* ForVar2(); + // Get the result (inserts the closing ")"). 
+ string* NewString(); + + private: + std::ostringstream* stream_; +}; + +template +string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) { + CheckOpMessageBuilder comb(exprtext); + MakeCheckOpValueString(comb.ForVar1(), v1); + MakeCheckOpValueString(comb.ForVar2(), v2); + return comb.NewString(); +} + +// Helper functions for CHECK_OP macro. +// The (int, int) specialization works around the issue that the compiler +// will not instantiate the template version of the function on values of +// unnamed enum type - see comment below. +#define TF_DEFINE_CHECK_OP_IMPL(name, op) \ + template \ + inline string* name##Impl(const T1& v1, const T2& v2, \ + const char* exprtext) { \ + if (TF_PREDICT_TRUE(v1 op v2)) \ + return NULL; \ + else \ + return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext); \ + } \ + inline string* name##Impl(int v1, int v2, const char* exprtext) { \ + return name##Impl(v1, v2, exprtext); \ + } + +// We use the full name Check_EQ, Check_NE, etc. in case the file including +// base/logging.h provides its own #defines for the simpler names EQ, NE, etc. +// This happens if, for example, those are used as token names in a +// yacc grammar. +TF_DEFINE_CHECK_OP_IMPL(Check_EQ, + == ) // Compilation error with CHECK_EQ(NULL, x)? +TF_DEFINE_CHECK_OP_IMPL(Check_NE, != ) // Use CHECK(x == NULL) instead. +TF_DEFINE_CHECK_OP_IMPL(Check_LE, <= ) +TF_DEFINE_CHECK_OP_IMPL(Check_LT, < ) +TF_DEFINE_CHECK_OP_IMPL(Check_GE, >= ) +TF_DEFINE_CHECK_OP_IMPL(Check_GT, > ) +#undef TF_DEFINE_CHECK_OP_IMPL + +// In optimized mode, use CheckOpString to hint to compiler that +// the while condition is unlikely. +#define CHECK_OP_LOG(name, op, val1, val2) \ + while (::tensorflow::internal::CheckOpString _result = \ + ::tensorflow::internal::name##Impl( \ + ::tensorflow::internal::GetReferenceableValue(val1), \ + ::tensorflow::internal::GetReferenceableValue(val2), \ + #val1 " " #op " " #val2)) \ + ::tensorflow::internal::LogMessageFatal(__FILE__, __LINE__) << *(_result.str_) + +#define CHECK_OP(name, op, val1, val2) CHECK_OP_LOG(name, op, val1, val2) + +// CHECK_EQ/NE/... +#define CHECK_EQ(val1, val2) CHECK_OP(Check_EQ, ==, val1, val2) +#define CHECK_NE(val1, val2) CHECK_OP(Check_NE, !=, val1, val2) +#define CHECK_LE(val1, val2) CHECK_OP(Check_LE, <=, val1, val2) +#define CHECK_LT(val1, val2) CHECK_OP(Check_LT, <, val1, val2) +#define CHECK_GE(val1, val2) CHECK_OP(Check_GE, >=, val1, val2) +#define CHECK_GT(val1, val2) CHECK_OP(Check_GT, >, val1, val2) +#define CHECK_NOTNULL(val) \ + ::tensorflow::internal::CheckNotNull(__FILE__, __LINE__, \ + "'" #val "' Must be non NULL", (val)) + +#ifndef NDEBUG +// DCHECK_EQ/NE/... +#define DCHECK(condition) CHECK(condition) +#define DCHECK_EQ(val1, val2) CHECK_EQ(val1, val2) +#define DCHECK_NE(val1, val2) CHECK_NE(val1, val2) +#define DCHECK_LE(val1, val2) CHECK_LE(val1, val2) +#define DCHECK_LT(val1, val2) CHECK_LT(val1, val2) +#define DCHECK_GE(val1, val2) CHECK_GE(val1, val2) +#define DCHECK_GT(val1, val2) CHECK_GT(val1, val2) + +#else + +#define DCHECK(condition) \ + while (false && (condition)) LOG(FATAL) + +// NDEBUG is defined, so DCHECK_EQ(x, y) and so on do nothing. +// However, we still want the compiler to parse x and y, because +// we don't want to lose potentially useful errors and warnings. +// _DCHECK_NOP is a helper, and should not be used outside of this file. 
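A brief usage sketch of the macros defined above (editor-added illustration; the function and variable names are hypothetical):

    #include "tensorflow/core/platform/logging.h"

    void Process(const char* buf, int n) {
      const char* p = CHECK_NOTNULL(buf);   // dies if buf is NULL
      CHECK_GE(n, 0) << "negative length";  // checked in all build modes; a
                                            // failure logs e.g.
                                            // "Check failed: n >= 0 (-1 vs. 0)"
      DCHECK_LT(n, 1 << 20);                // compiled out when NDEBUG is set
      VLOG(1) << "processing " << n << " bytes from " << p;
    }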
+#define _TF_DCHECK_NOP(x, y) \ + while (false && ((void)(x), (void)(y), 0)) LOG(FATAL) + +#define DCHECK_EQ(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_NE(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_LE(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_LT(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_GE(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_GT(x, y) _TF_DCHECK_NOP(x, y) + +#endif + +// These are for when you don't want a CHECK failure to print a verbose +// stack trace. The implementation of CHECK* in this file already doesn't. +#define QCHECK(condition) CHECK(condition) +#define QCHECK_EQ(x, y) CHECK_EQ(x, y) +#define QCHECK_NE(x, y) CHECK_NE(x, y) +#define QCHECK_LE(x, y) CHECK_LE(x, y) +#define QCHECK_LT(x, y) CHECK_LT(x, y) +#define QCHECK_GE(x, y) CHECK_GE(x, y) +#define QCHECK_GT(x, y) CHECK_GT(x, y) + +template +T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) { + if (t == nullptr) { + LogMessageFatal(file, line) << string(exprtext); + } + return std::forward(t); +} + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_ diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h new file mode 100644 index 0000000000..b26b418e1b --- /dev/null +++ b/tensorflow/core/platform/default/mutex.h @@ -0,0 +1,33 @@ +#ifndef TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_ +#define TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_ + +#include +#include +#include + +namespace tensorflow { + +enum LinkerInitialized { LINKER_INITIALIZED }; + +// A class that wraps around the std::mutex implementation, only adding an +// additional LinkerInitialized constructor interface. +class mutex : public std::mutex { + public: + mutex() {} + // The default implementation of std::mutex is safe to use after the linker + // initializations + explicit mutex(LinkerInitialized x) {} +}; + +using std::condition_variable; +typedef std::unique_lock mutex_lock; + +inline ConditionResult WaitForMilliseconds(mutex_lock* mu, + condition_variable* cv, int64 ms) { + std::cv_status s = cv->wait_for(*mu, std::chrono::milliseconds(ms)); + return (s == std::cv_status::timeout) ? 
kCond_Timeout : kCond_MaybeNotified; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_ diff --git a/tensorflow/core/platform/default/protobuf.h b/tensorflow/core/platform/default/protobuf.h new file mode 100644 index 0000000000..f6083c318d --- /dev/null +++ b/tensorflow/core/platform/default/protobuf.h @@ -0,0 +1,13 @@ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_H_ + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/text_format.h" + +namespace tensorflow { +namespace protobuf = ::google::protobuf; +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_H_ diff --git a/tensorflow/core/platform/default/stream_executor_util.h b/tensorflow/core/platform/default/stream_executor_util.h new file mode 100644 index 0000000000..d7fad4e233 --- /dev/null +++ b/tensorflow/core/platform/default/stream_executor_util.h @@ -0,0 +1,19 @@ +#ifndef TENSORFLOW_PLATFORM_DEFAULT_STREAM_EXECUTOR_UTIL_H_ +#define TENSORFLOW_PLATFORM_DEFAULT_STREAM_EXECUTOR_UTIL_H_ + +#include "tensorflow/stream_executor/lib/status.h" + +namespace tensorflow { + +namespace gpu = ::perftools::gputools; + +// On the open-source platform, stream_executor currently uses +// tensorflow::Status +inline Status FromStreamExecutorStatus( + const perftools::gputools::port::Status& s) { + return s; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_DEFAULT_STREAM_EXECUTOR_UTIL_H_ diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc new file mode 100644 index 0000000000..4004bf026b --- /dev/null +++ b/tensorflow/core/platform/default/test_benchmark.cc @@ -0,0 +1,162 @@ +#include "tensorflow/core/platform/test_benchmark.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace testing { + +static std::vector* all_benchmarks = nullptr; +static std::string label; +static int64 bytes_processed; +static int64 items_processed; +static int64 accum_time = 0; +static int64 start_time = 0; +static Env* env; + +Benchmark::Benchmark(const char* name, void (*fn)(int)) + : name_(name), num_args_(0), fn0_(fn) { + args_.push_back(-1); + Register(); +} + +Benchmark::Benchmark(const char* name, void (*fn)(int, int)) + : name_(name), num_args_(1), fn1_(fn) { + Register(); +} + +Benchmark* Benchmark::Arg(int x) { + CHECK_EQ(num_args_, 1); + args_.push_back(x); + return this; +} + +Benchmark* Benchmark::Range(int lo, int hi) { + Arg(lo); + for (int32 i = 1; i < kint32max / 8 && i < hi; i *= 8) { + Arg(i); + } + if (lo != hi) Arg(hi); + return this; +} + +void Benchmark::Run(const char* pattern) { + if (!all_benchmarks) return; + + if (StringPiece(pattern) == "all") { + pattern = ".*"; + } + + // Compute name width. 
+ int width = 10; + string name; + for (auto b : *all_benchmarks) { + name = b->name_; + for (auto arg : b->args_) { + name.resize(b->name_.size()); + if (arg >= 0) { + strings::StrAppend(&name, "/", arg); + } + if (RE2::PartialMatch(name, pattern)) { + width = std::max(width, name.size()); + } + } + } + + printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations"); + printf("%s\n", string(width + 22, '-').c_str()); + for (auto b : *all_benchmarks) { + name = b->name_; + for (auto arg : b->args_) { + name.resize(b->name_.size()); + if (arg >= 0) { + strings::StrAppend(&name, "/", arg); + } + if (!RE2::PartialMatch(name, pattern)) { + continue; + } + + int iters; + double seconds; + b->Run(arg, &iters, &seconds); + + char buf[100]; + std::string full_label = label; + if (bytes_processed > 0) { + snprintf(buf, sizeof(buf), " %.1fMB/s", + (bytes_processed * 1e-6) / seconds); + full_label += buf; + } + if (items_processed > 0) { + snprintf(buf, sizeof(buf), " %.1fM items/s", + (items_processed * 1e-6) / seconds); + full_label += buf; + } + printf("%-*s %10.0f %10d\t%s\n", width, name.c_str(), + seconds * 1e9 / iters, iters, full_label.c_str()); + } + } +} + +void Benchmark::Register() { + if (!all_benchmarks) all_benchmarks = new std::vector; + all_benchmarks->push_back(this); +} + +void Benchmark::Run(int arg, int* run_count, double* run_seconds) { + env = Env::Default(); + static const int64 kMinIters = 100; + static const int64 kMaxIters = 1000000000; + static const double kMinTime = 0.5; + int64 iters = kMinIters; + while (true) { + accum_time = 0; + start_time = env->NowMicros(); + bytes_processed = -1; + items_processed = -1; + label.clear(); + if (fn0_) { + (*fn0_)(iters); + } else { + (*fn1_)(iters, arg); + } + StopTiming(); + const double seconds = accum_time * 1e-6; + if (seconds >= kMinTime || iters >= kMaxIters) { + *run_count = iters; + *run_seconds = seconds; + return; + } + + // Update number of iterations. Overshoot by 40% in an attempt + // to succeed the next time. + double multiplier = 1.4 * kMinTime / std::max(seconds, 1e-9); + multiplier = std::min(10.0, multiplier); + if (multiplier <= 1.0) multiplier *= 2.0; + iters = std::max(multiplier * iters, iters + 1); + iters = std::min(iters, kMaxIters); + } +} + +// TODO(vrv): Add support for running a subset of benchmarks by having +// RunBenchmarks take in a spec (and maybe other options such as +// benchmark_min_time, etc). +void RunBenchmarks() { Benchmark::Run("all"); } +void SetLabel(const std::string& l) { label = l; } +void BytesProcessed(int64 n) { bytes_processed = n; } +void ItemsProcessed(int64 n) { items_processed = n; } +void StartTiming() { + if (start_time == 0) start_time = env->NowMicros(); +} +void StopTiming() { + if (start_time != 0) { + accum_time += (env->NowMicros() - start_time); + start_time = 0; + } +} +void UseRealTime() {} + +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/core/platform/default/thread_annotations.h b/tensorflow/core/platform/default/thread_annotations.h new file mode 100644 index 0000000000..fed39bf810 --- /dev/null +++ b/tensorflow/core/platform/default/thread_annotations.h @@ -0,0 +1,185 @@ +// Copyright (c) 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// --- +// +// This header file contains the macro definitions for thread safety +// annotations that allow the developers to document the locking policies +// of their multi-threaded code. The annotations can also help program +// analysis tools to identify potential thread safety issues. +// +// The primary documentation on these annotations is external: +// http://clang.llvm.org/docs/ThreadSafetyAnalysis.html +// +// The annotations are implemented using compiler attributes. +// Using the macros defined here instead of the raw attributes allows +// for portability and future compatibility. +// +// When referring to mutexes in the arguments of the attributes, you should +// use variable names or more complex expressions (e.g. my_object->mutex_) +// that evaluate to a concrete mutex object whenever possible. If the mutex +// you want to refer to is not in scope, you may use a member pointer +// (e.g. &MyClass::mutex_) to refer to a mutex in some (unknown) object. +// + +#ifndef TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_ +#define TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_ + +#if defined(__clang__) && (!defined(SWIG)) +#define THREAD_ANNOTATION_ATTRIBUTE__(x) __attribute__((x)) +#else +#define THREAD_ANNOTATION_ATTRIBUTE__(x) // no-op +#endif + +// Document if a shared variable/field needs to be protected by a mutex. +// GUARDED_BY allows the user to specify a particular mutex that should be +// held when accessing the annotated variable. GUARDED_VAR indicates that +// a shared variable is guarded by some unspecified mutex, for use in rare +// cases where a valid mutex expression cannot be specified. +#define GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(guarded_by(x)) +#define GUARDED_VAR THREAD_ANNOTATION_ATTRIBUTE__(guarded) + +// Document if the memory location pointed to by a pointer should be guarded +// by a mutex when dereferencing the pointer. PT_GUARDED_VAR is analogous to +// GUARDED_VAR. Note that a pointer variable to a shared memory location +// could itself be a shared variable. 
For example, if a shared global pointer +// q, which is guarded by mu1, points to a shared memory location that is +// guarded by mu2, q should be annotated as follows: +// int *q GUARDED_BY(mu1) PT_GUARDED_BY(mu2); +#define PT_GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(pt_guarded_by(x)) +#define PT_GUARDED_VAR THREAD_ANNOTATION_ATTRIBUTE__(pt_guarded) + +// Document the acquisition order between locks that can be held +// simultaneously by a thread. For any two locks that need to be annotated +// to establish an acquisition order, only one of them needs the annotation. +// (i.e. You don't have to annotate both locks with both ACQUIRED_AFTER +// and ACQUIRED_BEFORE.) +#define ACQUIRED_AFTER(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquired_after(__VA_ARGS__)) + +#define ACQUIRED_BEFORE(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquired_before(__VA_ARGS__)) + +// Document a function that expects a mutex to be held prior to entry. +// The mutex is expected to be held both on entry to and exit from the +// function. +#define EXCLUSIVE_LOCKS_REQUIRED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(exclusive_locks_required(__VA_ARGS__)) + +#define SHARED_LOCKS_REQUIRED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(shared_locks_required(__VA_ARGS__)) + +// Document the locks acquired in the body of the function. These locks +// cannot be held when calling this function (for instance, when the +// mutex implementation is non-reentrant). +#define LOCKS_EXCLUDED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(locks_excluded(__VA_ARGS__)) + +// Document a function that returns a mutex without acquiring it. For example, +// a public getter method that returns a pointer to a private mutex should +// be annotated with LOCK_RETURNED. +#define LOCK_RETURNED(x) THREAD_ANNOTATION_ATTRIBUTE__(lock_returned(x)) + +// Document if a class/type is a lockable type (such as the Mutex class). +#define LOCKABLE THREAD_ANNOTATION_ATTRIBUTE__(lockable) + +// Document if a class does RAII locking (such as the MutexLock class). +// The constructor should use LOCK_FUNCTION to specify the mutex that is +// acquired, and the destructor should use UNLOCK_FUNCTION with no arguments; +// the analysis will assume that the destructor unlocks whatever the +// constructor locked. +#define SCOPED_LOCKABLE THREAD_ANNOTATION_ATTRIBUTE__(scoped_lockable) + +// Document functions that acquire a lock in the body of a function, and do +// not release it. +#define EXCLUSIVE_LOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(exclusive_lock_function(__VA_ARGS__)) + +#define SHARED_LOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(shared_lock_function(__VA_ARGS__)) + +// Document functions that expect a lock to be held on entry to the function, +// and release it in the body of the function. +#define UNLOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(unlock_function(__VA_ARGS__)) + +// Document functions that try to acquire a lock, and return success or failure +// (or a non-boolean value that can be interpreted as a boolean). +// The first argument should be true for functions that return true on success, +// or false for functions that return false on success. The second argument +// specifies the mutex that is locked on success. If unspecified, it is assumed +// to be 'this'. +#define EXCLUSIVE_TRYLOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(exclusive_trylock_function(__VA_ARGS__)) + +#define SHARED_TRYLOCK_FUNCTION(...) 
\ + THREAD_ANNOTATION_ATTRIBUTE__(shared_trylock_function(__VA_ARGS__)) + +// Document functions that dynamically check to see if a lock is held, and fail +// if it is not held. +#define ASSERT_EXCLUSIVE_LOCK(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(assert_exclusive_lock(__VA_ARGS__)) + +#define ASSERT_SHARED_LOCK(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(assert_shared_lock(__VA_ARGS__)) + +// Turns off thread safety checking within the body of a particular function. +// This is used as an escape hatch for cases where either (a) the function +// is correct, but the locking is more complicated than the analyzer can handle, +// or (b) the function contains race conditions that are known to be benign. +#define NO_THREAD_SAFETY_ANALYSIS \ + THREAD_ANNOTATION_ATTRIBUTE__(no_thread_safety_analysis) + +// TS_UNCHECKED should be placed around lock expressions that are not valid +// C++ syntax, but which are present for documentation purposes. These +// annotations will be ignored by the analysis. +#define TS_UNCHECKED(x) "" + +// Disables warnings for a single read operation. This can be used to do racy +// reads of guarded data members, in cases where the race is benign. +#define TS_UNCHECKED_READ(x) \ + ::tensorflow::thread_safety_analysis::ts_unchecked_read(x) + +namespace tensorflow { +namespace thread_safety_analysis { + +// Takes a reference to a guarded data member, and returns an unguarded +// reference. +template +inline const T& ts_unchecked_read(const T& v) NO_THREAD_SAFETY_ANALYSIS { + return v; +} + +template +inline T& ts_unchecked_read(T& v) NO_THREAD_SAFETY_ANALYSIS { + return v; +} +} // namespace thread_safety_analysis +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_ diff --git a/tensorflow/core/platform/default/tracing.cc b/tensorflow/core/platform/default/tracing.cc new file mode 100644 index 0000000000..a4ddfad928 --- /dev/null +++ b/tensorflow/core/platform/default/tracing.cc @@ -0,0 +1,37 @@ +#include "tensorflow/core/platform/tracing.h" + +#include + +namespace tensorflow { +namespace port { + +void Tracing::RegisterEvent(EventCategory id, const char* name) { + // TODO(opensource): implement +} + +void Tracing::Initialize() {} + +static bool TryGetEnv(const char* name, const char** value) { + *value = getenv(name); + return *value != nullptr && (*value)[0] != '\0'; +} + +const char* Tracing::LogDir() { + const char* dir; + if (TryGetEnv("TEST_TMPDIR", &dir)) return dir; + if (TryGetEnv("TMP", &dir)) return dir; + if (TryGetEnv("TMPDIR", &dir)) return dir; + dir = "/tmp"; + if (access(dir, R_OK | W_OK | X_OK) == 0) return dir; + return "."; // Default to current directory. +} + +static bool DoInit() { + Tracing::Initialize(); + return true; +} + +static const bool dummy = DoInit(); + +} // namespace port +} // namespace tensorflow diff --git a/tensorflow/core/platform/default/tracing_impl.h b/tensorflow/core/platform/default/tracing_impl.h new file mode 100644 index 0000000000..e2f5d3cb3f --- /dev/null +++ b/tensorflow/core/platform/default/tracing_impl.h @@ -0,0 +1,44 @@ +#ifndef TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_ +#define TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_ + +// Stub implementations of tracing functionality. 
+ +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/tracing.h" + +namespace tensorflow { +namespace port { + +// Definitions that do nothing for platforms that don't have underlying thread +// tracing support. +#define TRACELITERAL(a) \ + do { \ + } while (0) +#define TRACESTRING(s) \ + do { \ + } while (0) +#define TRACEPRINTF(format, ...) \ + do { \ + } while (0) + +inline uint64 Tracing::UniqueId() { return random::New64(); } +inline bool Tracing::IsActive() { return false; } +inline void Tracing::RegisterCurrentThread(const char* name) {} + +// Posts an atomic threadscape event with the supplied category and arg. +inline void Tracing::RecordEvent(EventCategory category, uint64 arg) { + // TODO(opensource): Implement +} + +inline Tracing::ScopedActivity::ScopedActivity(EventCategory category, + uint64 arg) + : enabled_(false), region_id_(category_id_[category]) {} + +inline Tracing::ScopedActivity::~ScopedActivity() {} + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_ diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc new file mode 100644 index 0000000000..3e3c0ad74e --- /dev/null +++ b/tensorflow/core/platform/env.cc @@ -0,0 +1,129 @@ +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +Env::~Env() {} + +RandomAccessFile::~RandomAccessFile() {} + +WritableFile::~WritableFile() {} + +Thread::~Thread() {} + +EnvWrapper::~EnvWrapper() {} + +Status ReadFileToString(Env* env, const string& fname, string* data) { + data->clear(); + RandomAccessFile* file; + Status s = env->NewRandomAccessFile(fname, &file); + if (!s.ok()) { + return s; + } + int64 offset = 0; + static const int kBufferSize = 8192; + char* space = new char[kBufferSize]; + while (true) { + StringPiece fragment; + s = file->Read(offset, kBufferSize, &fragment, space); + if (!s.ok()) { + if (errors::IsOutOfRange(s)) { // No more bytes, but not an error + s = Status::OK(); + data->append(fragment.data(), fragment.size()); + } + break; + } + offset += fragment.size(); + data->append(fragment.data(), fragment.size()); + if (fragment.empty()) { + break; + } + } + delete[] space; + delete file; + return s; +} + +Status WriteStringToFile(Env* env, const string& fname, + const StringPiece& data) { + WritableFile* file; + Status s = env->NewWritableFile(fname, &file); + if (!s.ok()) { + return s; + } + s = file->Append(data); + if (s.ok()) { + s = file->Close(); + } + delete file; + return s; +} + +// A ZeroCopyInputStream on a RandomAccessFile. 
+namespace { +class FileStream : public ::tensorflow::protobuf::io::ZeroCopyInputStream { + public: + explicit FileStream(RandomAccessFile* file) : file_(file), pos_(0) {} + + void BackUp(int count) override { pos_ -= count; } + bool Skip(int count) override { + pos_ += count; + return true; + } + int64 ByteCount() const override { return pos_; } + Status status() const { return status_; } + + bool Next(const void** data, int* size) override { + StringPiece result; + Status s = file_->Read(pos_, kBufSize, &result, scratch_); + if (result.empty()) { + status_ = s; + return false; + } + pos_ += result.size(); + *data = result.data(); + *size = result.size(); + return true; + } + + private: + static const int kBufSize = 512 << 10; + + RandomAccessFile* file_; + int64 pos_; + Status status_; + char scratch_[kBufSize]; +}; + +} // namespace + +Status ReadBinaryProto(Env* env, const string& fname, + ::tensorflow::protobuf::MessageLite* proto) { + RandomAccessFile* file; + auto s = env->NewRandomAccessFile(fname, &file); + if (!s.ok()) { + return s; + } + std::unique_ptr file_holder(file); + std::unique_ptr stream(new FileStream(file)); + + // TODO(jiayq): the following coded stream is for debugging purposes to allow + // one to parse arbitrarily large messages for MessageLite. One most likely + // doesn't want to put protobufs larger than 64MB on Android, so we should + // eventually remove this and quit loud when a large protobuf is passed in. + ::tensorflow::protobuf::io::CodedInputStream coded_stream(stream.get()); + // Total bytes hard limit / warning limit are set to 1GB and 512MB + // respectively. + coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); + + if (!proto->ParseFromCodedStream(&coded_stream)) { + s = stream->status(); + if (s.ok()) { + s = Status(error::DATA_LOSS, "Parse error"); + } + } + return s; +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc new file mode 100644 index 0000000000..be15c4a5cb --- /dev/null +++ b/tensorflow/core/platform/env_test.cc @@ -0,0 +1,31 @@ +#include "tensorflow/core/public/env.h" + +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include + +namespace tensorflow { + +struct EnvTest {}; + +TEST(EnvTest, ReadFileToString) { + Env* env = Env::Default(); + const string dir = testing::TmpDir(); + for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000}) { + const string filename = io::JoinPath(dir, strings::StrCat("file", length)); + + // Write a file with the given length + string input(length, 0); + for (int i = 0; i < length; i++) input[i] = i; + WriteStringToFile(env, filename, input); + + // Read the file back and check equality + string output; + TF_CHECK_OK(ReadFileToString(env, filename, &output)); + CHECK_EQ(length, output.size()); + CHECK_EQ(input, output); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/init_main.h b/tensorflow/core/platform/init_main.h new file mode 100644 index 0000000000..ce3d1fbc2f --- /dev/null +++ b/tensorflow/core/platform/init_main.h @@ -0,0 +1,16 @@ +#ifndef TENSORFLOW_PLATFORM_INIT_MAIN_H_ +#define TENSORFLOW_PLATFORM_INIT_MAIN_H_ + +namespace tensorflow { +namespace port { + +// Platform-specific initialization routine that may be invoked by a +// main() program that uses TensorFlow. +// +// Default implementation does nothing. 
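A minimal sketch (editor-added, not part of the patch) of how a program might call this routine before using TensorFlow:

    #include "tensorflow/core/platform/init_main.h"

    int main(int argc, char** argv) {
      // Lets the platform consume any platform-specific flags; the default
      // implementation is a no-op.
      tensorflow::port::InitMain(argv[0], &argc, &argv);
      // ... application code ...
      return 0;
    }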
+void InitMain(const char* usage, int* argc, char*** argv); + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_INIT_MAIN_H_ diff --git a/tensorflow/core/platform/integral_types_test.cc b/tensorflow/core/platform/integral_types_test.cc new file mode 100644 index 0000000000..067787a9f4 --- /dev/null +++ b/tensorflow/core/platform/integral_types_test.cc @@ -0,0 +1,33 @@ +#include "tensorflow/core/platform/port.h" + +#include + +namespace tensorflow { +namespace { + +TEST(IntegralTypes, Basic) { + EXPECT_EQ(1, sizeof(int8)); + EXPECT_EQ(2, sizeof(int16)); + EXPECT_EQ(4, sizeof(int32)); + EXPECT_EQ(8, sizeof(int64)); + + EXPECT_EQ(1, sizeof(uint8)); + EXPECT_EQ(2, sizeof(uint16)); + EXPECT_EQ(4, sizeof(uint32)); + EXPECT_EQ(8, sizeof(uint64)); +} + +TEST(IntegralTypes, MinAndMaxConstants) { + EXPECT_EQ(static_cast(kint8min), static_cast(kint8max) + 1); + EXPECT_EQ(static_cast(kint16min), static_cast(kint16max) + 1); + EXPECT_EQ(static_cast(kint32min), static_cast(kint32max) + 1); + EXPECT_EQ(static_cast(kint64min), static_cast(kint64max) + 1); + + EXPECT_EQ(0, static_cast(kuint8max + 1)); + EXPECT_EQ(0, static_cast(kuint16max + 1)); + EXPECT_EQ(0, static_cast(kuint32max + 1)); + EXPECT_EQ(0, static_cast(kuint64max + 1)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/platform/logging.h b/tensorflow/core/platform/logging.h new file mode 100644 index 0000000000..66caf22ede --- /dev/null +++ b/tensorflow/core/platform/logging.h @@ -0,0 +1,12 @@ +#ifndef TENSORFLOW_PLATFORM_LOGGING_H_ +#define TENSORFLOW_PLATFORM_LOGGING_H_ + +#include "tensorflow/core/platform/port.h" // To pick up PLATFORM_define + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) +#include "base/logging.h" +#else +#include "tensorflow/core/platform/default/logging.h" +#endif + +#endif // TENSORFLOW_PLATFORM_LOGGING_H_ diff --git a/tensorflow/core/platform/logging_test.cc b/tensorflow/core/platform/logging_test.cc new file mode 100644 index 0000000000..03d734ae95 --- /dev/null +++ b/tensorflow/core/platform/logging_test.cc @@ -0,0 +1,76 @@ +#include "tensorflow/core/platform/logging.h" +#include + +namespace tensorflow { + +TEST(Logging, Log) { + LOG(INFO) << "Hello"; + LOG(INFO) << "Another log message"; + LOG(ERROR) << "Error message"; + VLOG(1) << "A VLOG message"; + VLOG(2) << "A higher VLOG message"; +} + +TEST(Logging, CheckChecks) { + CHECK(true); + CHECK(7 > 5); + string a("abc"); + string b("xyz"); + CHECK_EQ(a, a); + CHECK_NE(a, b); + CHECK_EQ(3, 3); + CHECK_NE(4, 3); + CHECK_GT(4, 3); + CHECK_GE(3, 3); + CHECK_LT(2, 3); + CHECK_LE(2, 3); + + DCHECK(true); + DCHECK(7 > 5); + DCHECK_EQ(a, a); + DCHECK_NE(a, b); + DCHECK_EQ(3, 3); + DCHECK_NE(4, 3); + DCHECK_GT(4, 3); + DCHECK_GE(3, 3); + DCHECK_LT(2, 3); + DCHECK_LE(2, 3); +} + +TEST(LoggingDeathTest, FailedChecks) { + string a("abc"); + string b("xyz"); + const char* p_const = "hello there"; + const char* p_null_const = nullptr; + char mybuf[10]; + char* p_non_const = mybuf; + char* p_null = nullptr; + CHECK_NOTNULL(p_const); + CHECK_NOTNULL(p_non_const); + + ASSERT_DEATH(CHECK(false), "false"); + ASSERT_DEATH(CHECK(9 < 7), "9 < 7"); + ASSERT_DEATH(CHECK_EQ(a, b), "a == b"); + ASSERT_DEATH(CHECK_EQ(3, 4), "3 == 4"); + ASSERT_DEATH(CHECK_NE(3, 3), "3 != 3"); + ASSERT_DEATH(CHECK_GT(2, 3), "2 > 3"); + ASSERT_DEATH(CHECK_GE(2, 3), "2 >= 3"); + ASSERT_DEATH(CHECK_LT(3, 2), "3 < 2"); + ASSERT_DEATH(CHECK_LE(3, 2), "3 <= 2"); + ASSERT_DEATH(CHECK(false), "false"); + ASSERT_DEATH(printf("%s", 
CHECK_NOTNULL(p_null)), "Must be non NULL"); + ASSERT_DEATH(printf("%s", CHECK_NOTNULL(p_null_const)), "Must be non NULL"); +#ifndef NDEBUG + ASSERT_DEATH(DCHECK(9 < 7), "9 < 7"); + ASSERT_DEATH(DCHECK(9 < 7), "9 < 7"); + ASSERT_DEATH(DCHECK_EQ(a, b), "a == b"); + ASSERT_DEATH(DCHECK_EQ(3, 4), "3 == 4"); + ASSERT_DEATH(DCHECK_NE(3, 3), "3 != 3"); + ASSERT_DEATH(DCHECK_GT(2, 3), "2 > 3"); + ASSERT_DEATH(DCHECK_GE(2, 3), "2 >= 3"); + ASSERT_DEATH(DCHECK_LT(3, 2), "3 < 2"); + ASSERT_DEATH(DCHECK_LE(3, 2), "3 <= 2"); +#endif +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/port.h b/tensorflow/core/platform/port.h new file mode 100644 index 0000000000..fef20f7753 --- /dev/null +++ b/tensorflow/core/platform/port.h @@ -0,0 +1,228 @@ +#ifndef TENSORFLOW_PLATFORM_PORT_H_ +#define TENSORFLOW_PLATFORM_PORT_H_ + +#include +#include + +#if !defined(PLATFORM_POSIX) && !defined(PLATFORM_GOOGLE) && \ + !defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID) + +// Choose which platform we are on. +#if defined(ANDROID) || defined(__ANDROID__) +#define PLATFORM_POSIX_ANDROID +#elif defined(__APPLE__) +#define PLATFORM_POSIX +#else +// If no platform specified, use: +#define PLATFORM_POSIX +#endif + +#endif + +// Define tensorflow::string to refer to appropriate platform specific type. +namespace tensorflow { +#if defined(PLATFORM_GOOGLE) +using ::string; +#else +using std::string; +#endif +} // namespace tensorflow + +namespace tensorflow { +enum ConditionResult { kCond_Timeout, kCond_MaybeNotified }; +} // namespace tensorflow + +// Include appropriate platform-dependent implementations of mutex etc. +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/core/platform/google/integral_types.h" +#include "tensorflow/core/platform/google/mutex.h" +#include "tensorflow/core/platform/google/dynamic_annotations.h" +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ + defined(PLATFORM_GOOGLE_ANDROID) +#include "tensorflow/core/platform/default/integral_types.h" +#include "tensorflow/core/platform/default/mutex.h" +#include "tensorflow/core/platform/default/dynamic_annotations.h" +#else +#error Define the appropriate PLATFORM_ macro for this platform +#endif + +namespace tensorflow { + +static const uint8 kuint8max = ((uint8)0xFF); +static const uint16 kuint16max = ((uint16)0xFFFF); +static const uint32 kuint32max = ((uint32)0xFFFFFFFF); +static const uint64 kuint64max = ((uint64)0xFFFFFFFFFFFFFFFFull); +static const int8 kint8min = ((int8)~0x7F); +static const int8 kint8max = ((int8)0x7F); +static const int16 kint16min = ((int16)~0x7FFF); +static const int16 kint16max = ((int16)0x7FFF); +static const int32 kint32min = ((int32)~0x7FFFFFFF); +static const int32 kint32max = ((int32)0x7FFFFFFF); +static const int64 kint64min = ((int64)~0x7FFFFFFFFFFFFFFFll); +static const int64 kint64max = ((int64)0x7FFFFFFFFFFFFFFFll); + +// A typedef for a uint64 used as a short fingerprint. +typedef uint64 Fprint; + +// The mutex library included above defines: +// class mutex; +// class mutex_lock; +// class condition_variable; +// It also defines the following: + +// Like "cv->wait(*mu)", except that it only waits for up to "ms" milliseconds. +// +// Returns kCond_Timeout if the timeout expired without this +// thread noticing a signal on the condition variable. 
Otherwise may +// return either kCond_Timeout or kCond_MaybeNotified +ConditionResult WaitForMilliseconds(mutex_lock* mu, condition_variable* cv, + int64 ms); +} // namespace tensorflow + +namespace tensorflow { +namespace port { + +// TODO(jeff,sanjay): Make portable +static const bool kLittleEndian = true; + +// TODO(jeff,sanjay): Find appropriate places for all the code below. +// Possible places for any particular item below: +// (a) Here, so it gets reimplemented on every platform +// (b) Env +// (c) config.h (auto-generated by autotools?) +// (d) macros.h +// ... + +// Return the hostname of the machine on which this process is running +string Hostname(); + +// Returns an estimate of the number of schedulable CPUs for this +// process. Usually, it's constant throughout the lifetime of a +// process, but it might change if the underlying cluster management +// software can change it dynamically. +int NumSchedulableCPUs(); + +// Some platforms require that filenames be of a certain form when +// used for logging. This function is invoked to allow platforms to +// adjust the filename used for logging appropriately, if necessary +// (most ports can just do nothing). If any changes are necessary, the +// implementation should mutate "*filename" appropriately. +void AdjustFilenameForLogging(string* filename); + +// Aligned allocation/deallocation +void* aligned_malloc(size_t size, int minimum_alignment); +void aligned_free(void* aligned_memory); + +// Prefetching support +// +// Defined behavior on some of the uarchs: +// PREFETCH_HINT_T0: +// prefetch to all levels of the hierarchy (except on p4: prefetch to L2) +// PREFETCH_HINT_NTA: +// p4: fetch to L2, but limit to 1 way (out of the 8 ways) +// core: skip L2, go directly to L1 +// k8 rev E and later: skip L2, can go to either of the 2-ways in L1 +enum PrefetchHint { + PREFETCH_HINT_T0 = 3, // More temporal locality + PREFETCH_HINT_T1 = 2, + PREFETCH_HINT_T2 = 1, // Less temporal locality + PREFETCH_HINT_NTA = 0 // No temporal locality +}; +template +void prefetch(const void* x); + +// Snappy compression/decompression support +bool Snappy_Compress(const char* input, size_t length, string* output); + +bool Snappy_GetUncompressedLength(const char* input, size_t length, + size_t* result); +bool Snappy_Uncompress(const char* input, size_t length, char* output); + +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L +// Define this to 1 if the code is compiled in C++11 mode; leave it +// undefined otherwise. Do NOT define it to 0 -- that causes +// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'. 
+#define LANG_CXX11 1 +#endif + +// Compiler attributes +#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG) +// Compiler supports GCC-style attributes +#define TF_ATTRIBUTE_NORETURN __attribute__((noreturn)) +#define TF_ATTRIBUTE_NOINLINE __attribute__((noinline)) +#define TF_ATTRIBUTE_UNUSED __attribute__((unused)) +#define TF_ATTRIBUTE_COLD __attribute__((cold)) +#define TF_PACKED __attribute__((packed)) +#define TF_MUST_USE_RESULT __attribute__((warn_unused_result)) +#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__printf__, string_index, first_to_check))) +#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__scanf__, string_index, first_to_check))) + +#else +// Non-GCC equivalents +#define TF_ATTRIBUTE_NORETURN +#define TF_ATTRIBUTE_NOINLINE +#define TF_ATTRIBUTE_UNUSED +#define TF_ATTRIBUTE_COLD +#define TF_MUST_USE_RESULT +#define TF_PACKED +#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) +#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) +#endif + +// GCC can be told that a certain branch is not likely to be taken (for +// instance, a CHECK failure), and use that information in static analysis. +// Giving it this information can help it optimize for the common case in +// the absence of better information (ie. -fprofile-arcs). +// +#if defined(COMPILER_GCC3) +#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else +#define TF_PREDICT_FALSE(x) x +#define TF_PREDICT_TRUE(x) x +#endif + +// --------------------------------------------------------------------------- +// Inline implementations of some performance-critical methods +// --------------------------------------------------------------------------- +template +inline void prefetch(const void* x) { +#if defined(__llvm__) || defined(COMPILER_GCC) + __builtin_prefetch(x, 0, hint); +#else +// You get no effect. Feel free to add more sections above. +#endif +} + +// A macro to disallow the copy constructor and operator= functions +// This is usually placed in the private: declarations for a class. +#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + void operator=(const TypeName&) = delete + +// The TF_ARRAYSIZE(arr) macro returns the # of elements in an array arr. +// +// The expression TF_ARRAYSIZE(a) is a compile-time constant of type +// size_t. 
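For illustration (an editor-added sketch; the class and function names are hypothetical), the branch-prediction and copy-suppression helpers above are typically used like this:

    class Counter {
     public:
      Counter() : n_(0) {}
      void Increment() { ++n_; }

     private:
      int n_;
      TF_DISALLOW_COPY_AND_ASSIGN(Counter);  // deletes copy ctor and operator=
    };

    inline int ClampNonNegative(int v) {
      if (TF_PREDICT_FALSE(v < 0)) return 0;  // mark the error path as unlikely
      return v;
    }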
+#define TF_ARRAYSIZE(a) \ + ((sizeof(a) / sizeof(*(a))) / \ + static_cast(!(sizeof(a) % sizeof(*(a))))) + +#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning) +#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough") +#define TF_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT +#endif +#endif + +#ifndef TF_FALLTHROUGH_INTENDED +#define TF_FALLTHROUGH_INTENDED \ + do { \ + } while (0) +#endif + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_PORT_H_ diff --git a/tensorflow/core/platform/port_test.cc b/tensorflow/core/platform/port_test.cc new file mode 100644 index 0000000000..8cf1c30aa3 --- /dev/null +++ b/tensorflow/core/platform/port_test.cc @@ -0,0 +1,48 @@ +#include "tensorflow/core/platform/port.h" +#include +#include "tensorflow/core/lib/core/threadpool.h" +#include + +namespace tensorflow { +namespace port { + +TEST(Port, AlignedMalloc) { + for (size_t alignment = 1; alignment <= 1 << 20; alignment <<= 1) { + void* p = aligned_malloc(1, alignment); + ASSERT_TRUE(p != NULL) << "aligned_malloc(1, " << alignment << ")"; + uintptr_t pval = reinterpret_cast(p); + EXPECT_EQ(pval % alignment, 0); + aligned_free(p); + } +} + +TEST(ConditionVariable, WaitForMilliseconds_Timeout) { + mutex m; + mutex_lock l(m); + condition_variable cv; + time_t start = time(NULL); + EXPECT_EQ(WaitForMilliseconds(&l, &cv, 3000), kCond_Timeout); + time_t finish = time(NULL); + EXPECT_GE(finish - start, 3); +} + +TEST(ConditionVariable, WaitForMilliseconds_Signalled) { + thread::ThreadPool pool(Env::Default(), "test", 1); + mutex m; + mutex_lock l(m); + condition_variable cv; + time_t start = time(NULL); + // Sleep for just 1 second then notify. We have a timeout of 3 secs, + // so the condition variable will notice the cv signal before the timeout. 
+ pool.Schedule([&m, &cv]() { + sleep(1); + mutex_lock l(m); + cv.notify_all(); + }); + EXPECT_EQ(WaitForMilliseconds(&l, &cv, 3000), kCond_MaybeNotified); + time_t finish = time(NULL); + EXPECT_LT(finish - start, 3); +} + +} // namespace port +} // namespace tensorflow diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc new file mode 100644 index 0000000000..6ba2010005 --- /dev/null +++ b/tensorflow/core/platform/posix/env.cc @@ -0,0 +1,385 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +error::Code ErrnoToCode(int err_number) { + error::Code code; + switch (err_number) { + case 0: + code = error::OK; + break; + case EINVAL: // Invalid argument + case ENAMETOOLONG: // Filename too long + case E2BIG: // Argument list too long + case EDESTADDRREQ: // Destination address required + case EDOM: // Mathematics argument out of domain of function + case EFAULT: // Bad address + case EILSEQ: // Illegal byte sequence + case ENOPROTOOPT: // Protocol not available + case ENOSTR: // Not a STREAM + case ENOTSOCK: // Not a socket + case ENOTTY: // Inappropriate I/O control operation + case EPROTOTYPE: // Protocol wrong type for socket + case ESPIPE: // Invalid seek + code = error::INVALID_ARGUMENT; + break; + case ETIMEDOUT: // Connection timed out + case ETIME: // Timer expired + code = error::DEADLINE_EXCEEDED; + break; + case ENODEV: // No such device + case ENOENT: // No such file or directory + case ENXIO: // No such device or address + case ESRCH: // No such process + code = error::NOT_FOUND; + break; + case EEXIST: // File exists + case EADDRNOTAVAIL: // Address not available + case EALREADY: // Connection already in progress + code = error::ALREADY_EXISTS; + break; + case EPERM: // Operation not permitted + case EACCES: // Permission denied + case EROFS: // Read only file system + code = error::PERMISSION_DENIED; + break; + case ENOTEMPTY: // Directory not empty + case EISDIR: // Is a directory + case ENOTDIR: // Not a directory + case EADDRINUSE: // Address already in use + case EBADF: // Invalid file descriptor + case EBUSY: // Device or resource busy + case ECHILD: // No child processes + case EISCONN: // Socket is connected + case ENOTBLK: // Block device required + case ENOTCONN: // The socket is not connected + case EPIPE: // Broken pipe + case ESHUTDOWN: // Cannot send after transport endpoint shutdown + case ETXTBSY: // Text file busy + code = error::FAILED_PRECONDITION; + break; + case ENOSPC: // No space left on device + case EDQUOT: // Disk quota exceeded + case EMFILE: // Too many open files + case EMLINK: // Too many links + case ENFILE: // Too many open files in system + case ENOBUFS: // No buffer space available + case ENODATA: // No message is available on the STREAM read queue + case ENOMEM: // Not enough space + case ENOSR: // No STREAM resources + case EUSERS: // Too many users + code = error::RESOURCE_EXHAUSTED; + break; + case EFBIG: // File too large + case EOVERFLOW: // Value too large to be stored in data type + case ERANGE: // Result too large + code = error::OUT_OF_RANGE; + break; + case ENOSYS: // Function not implemented + case ENOTSUP: // Operation not supported + case EAFNOSUPPORT: // Address family not supported + case EPFNOSUPPORT: // Protocol family not supported + case 
EPROTONOSUPPORT: // Protocol not supported + case ESOCKTNOSUPPORT: // Socket type not supported + case EXDEV: // Improper link + code = error::UNIMPLEMENTED; + break; + case EAGAIN: // Resource temporarily unavailable + case ECONNREFUSED: // Connection refused + case ECONNABORTED: // Connection aborted + case ECONNRESET: // Connection reset + case EINTR: // Interrupted function call + case EHOSTDOWN: // Host is down + case EHOSTUNREACH: // Host is unreachable + case ENETDOWN: // Network is down + case ENETRESET: // Connection aborted by network + case ENETUNREACH: // Network unreachable + case ENOLCK: // No locks available + case ENOLINK: // Link has been severed +#if !defined(__APPLE__) + case ENONET: // Machine is not on the network +#endif + code = error::UNAVAILABLE; + break; + case EDEADLK: // Resource deadlock avoided + case ESTALE: // Stale file handle + code = error::ABORTED; + break; + case ECANCELED: // Operation cancelled + code = error::CANCELLED; + break; + // NOTE: If you get any of the following (especially in a + // reproducible way) and can propose a better mapping, + // please email the owners about updating this mapping. + case EBADMSG: // Bad message + case EIDRM: // Identifier removed + case EINPROGRESS: // Operation in progress + case EIO: // I/O error + case ELOOP: // Too many levels of symbolic links + case ENOEXEC: // Exec format error + case ENOMSG: // No message of the desired type + case EPROTO: // Protocol error + case EREMOTE: // Object is remote + code = error::UNKNOWN; + break; + default: { + code = error::UNKNOWN; + break; + } + } + return code; +} + +static Status IOError(const string& context, int err_number) { + auto code = ErrnoToCode(err_number); + if (code == error::UNKNOWN) { + return Status(ErrnoToCode(err_number), + context + "; " + strerror(err_number)); + } else { + return Status(ErrnoToCode(err_number), context); + } +} + +// pread() based random-access +class PosixRandomAccessFile : public RandomAccessFile { + private: + string filename_; + int fd_; + + public: + PosixRandomAccessFile(const string& fname, int fd) + : filename_(fname), fd_(fd) {} + ~PosixRandomAccessFile() override { close(fd_); } + + Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + Status s; + char* dst = scratch; + while (n > 0 && s.ok()) { + ssize_t r = pread(fd_, dst, n, static_cast(offset)); + if (r > 0) { + dst += r; + n -= r; + offset += r; + } else if (r == 0) { + s = Status(error::OUT_OF_RANGE, "Read less bytes than requested"); + } else if (errno == EINTR || errno == EAGAIN) { + // Retry + } else { + s = IOError(filename_, errno); + } + } + *result = StringPiece(scratch, dst - scratch); + return s; + } +}; + +class PosixWritableFile : public WritableFile { + private: + string filename_; + FILE* file_; + + public: + PosixWritableFile(const string& fname, FILE* f) + : filename_(fname), file_(f) {} + + ~PosixWritableFile() override { + if (file_ != NULL) { + // Ignoring any potential errors + fclose(file_); + } + } + + Status Append(const StringPiece& data) override { + size_t r = fwrite(data.data(), 1, data.size(), file_); + if (r != data.size()) { + return IOError(filename_, errno); + } + return Status::OK(); + } + + Status Close() override { + Status result; + if (fclose(file_) != 0) { + result = IOError(filename_, errno); + } + file_ = NULL; + return result; + } + + Status Flush() override { + if (fflush(file_) != 0) { + return IOError(filename_, errno); + } + return Status::OK(); + } + + Status Sync() override { + Status 
s; + if (fflush(file_) != 0) { + s = IOError(filename_, errno); + } + return s; + } +}; + +class StdThread : public Thread { + public: + // name and thread_options are both ignored. + StdThread(const ThreadOptions& thread_options, const string& name, + std::function fn) + : thread_(fn) {} + ~StdThread() { thread_.join(); } + + private: + std::thread thread_; +}; + +class PosixEnv : public Env { + public: + PosixEnv() {} + + ~PosixEnv() override { LOG(FATAL) << "Env::Default() must not be destroyed"; } + + Status NewRandomAccessFile(const string& fname, + RandomAccessFile** result) override { + *result = NULL; + Status s; + int fd = open(fname.c_str(), O_RDONLY); + if (fd < 0) { + s = IOError(fname, errno); + } else { + *result = new PosixRandomAccessFile(fname, fd); + } + return s; + } + + Status NewWritableFile(const string& fname, WritableFile** result) override { + Status s; + FILE* f = fopen(fname.c_str(), "w"); + if (f == NULL) { + *result = NULL; + s = IOError(fname, errno); + } else { + *result = new PosixWritableFile(fname, f); + } + return s; + } + + Status NewAppendableFile(const string& fname, + WritableFile** result) override { + Status s; + FILE* f = fopen(fname.c_str(), "a"); + if (f == NULL) { + *result = NULL; + s = IOError(fname, errno); + } else { + *result = new PosixWritableFile(fname, f); + } + return s; + } + + bool FileExists(const string& fname) override { + return access(fname.c_str(), F_OK) == 0; + } + + Status GetChildren(const string& dir, std::vector* result) override { + result->clear(); + DIR* d = opendir(dir.c_str()); + if (d == NULL) { + return IOError(dir, errno); + } + struct dirent* entry; + while ((entry = readdir(d)) != NULL) { + StringPiece basename = entry->d_name; + if ((basename != ".") && (basename != "..")) { + result->push_back(entry->d_name); + } + } + closedir(d); + return Status::OK(); + } + + Status DeleteFile(const string& fname) override { + Status result; + if (unlink(fname.c_str()) != 0) { + result = IOError(fname, errno); + } + return result; + } + + Status CreateDir(const string& name) override { + Status result; + if (mkdir(name.c_str(), 0755) != 0) { + result = IOError(name, errno); + } + return result; + } + + Status DeleteDir(const string& name) override { + Status result; + if (rmdir(name.c_str()) != 0) { + result = IOError(name, errno); + } + return result; + } + + Status GetFileSize(const string& fname, uint64* size) override { + Status s; + struct stat sbuf; + if (stat(fname.c_str(), &sbuf) != 0) { + *size = 0; + s = IOError(fname, errno); + } else { + *size = sbuf.st_size; + } + return s; + } + + Status RenameFile(const string& src, const string& target) override { + Status result; + if (rename(src.c_str(), target.c_str()) != 0) { + result = IOError(src, errno); + } + return result; + } + + uint64 NowMicros() override { + struct timeval tv; + gettimeofday(&tv, NULL); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; + } + + void SleepForMicroseconds(int micros) override { usleep(micros); } + + Thread* StartThread(const ThreadOptions& thread_options, const string& name, + std::function fn) override { + return new StdThread(thread_options, name, fn); + } +}; + +} // namespace +#if defined(PLATFORM_POSIX) || defined(__ANDROID__) +Env* Env::Default() { + static Env* default_env = new PosixEnv; + return default_env; +} +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc new file mode 100644 index 0000000000..b4a1570ef9 --- /dev/null +++ 
b/tensorflow/core/platform/posix/port.cc @@ -0,0 +1,92 @@ +#include "tensorflow/core/platform/port.h" +#if defined(__linux) && !defined(__ANDROID__) +#include +#endif +#include +#include +#include +#include +#ifdef SNAPPY +#include +#endif + +namespace tensorflow { +namespace port { + +void InitMain(const char* usage, int* argc, char*** argv) {} + +string Hostname() { + char hostname[1024]; + gethostname(hostname, sizeof hostname); + hostname[sizeof hostname - 1] = 0; + return string(hostname); +} + +int NumSchedulableCPUs() { +#if defined(__linux) && !defined(__ANDROID__) + cpu_set_t cpuset; + if (sched_getaffinity(0, sizeof(cpu_set_t), &cpuset) == 0) { + return CPU_COUNT(&cpuset); + } + perror("sched_getaffinity"); +#endif + const int kDefaultCores = 4; // Semi-conservative guess + fprintf(stderr, "can't determine number of CPU cores: assuming %d\n", + kDefaultCores); + return kDefaultCores; +} + +void* aligned_malloc(size_t size, int minimum_alignment) { +#if defined(__ANDROID__) + return memalign(minimum_alignment, size); +#else // !__ANDROID__ + void* ptr = NULL; + // posix_memalign requires that the requested alignment be at least + // sizeof(void*). In this case, fall back on malloc which should return + // memory aligned to at least the size of a pointer. + const int required_alignment = sizeof(void*); + if (minimum_alignment < required_alignment) return malloc(size); + if (posix_memalign(&ptr, minimum_alignment, size) != 0) + return NULL; + else + return ptr; +#endif +} + +void aligned_free(void* aligned_memory) { free(aligned_memory); } + +void AdjustFilenameForLogging(string* filename) { + // Nothing to do +} + +bool Snappy_Compress(const char* input, size_t length, string* output) { +#ifdef SNAPPY + output->resize(snappy::MaxCompressedLength(length)); + size_t outlen; + snappy::RawCompress(input, length, &(*output)[0], &outlen); + output->resize(outlen); + return true; +#else + return false; +#endif +} + +bool Snappy_GetUncompressedLength(const char* input, size_t length, + size_t* result) { +#ifdef SNAPPY + return snappy::GetUncompressedLength(input, length, result); +#else + return false; +#endif +} + +bool Snappy_Uncompress(const char* input, size_t length, char* output) { +#ifdef SNAPPY + return snappy::RawUncompress(input, length, output); +#else + return false; +#endif +} + +} // namespace port +} // namespace tensorflow diff --git a/tensorflow/core/platform/protobuf.h b/tensorflow/core/platform/protobuf.h new file mode 100644 index 0000000000..3a166b3973 --- /dev/null +++ b/tensorflow/core/platform/protobuf.h @@ -0,0 +1,29 @@ +#ifndef TENSORFLOW_PLATFORM_PROTOBUF_H_ +#define TENSORFLOW_PLATFORM_PROTOBUF_H_ + +// Import whatever namespace protobuf comes from into the +// ::tensorflow::protobuf namespace. +// +// TensorFlow code should the ::tensorflow::protobuf namespace to refer +// to all protobuf APIs. + +#include "tensorflow/core/platform/port.h" +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/core/platform/google/protobuf.h" +#elif defined(PLATFORM_GOOGLE_ANDROID) +#include "tensorflow/core/platform/google/protobuf_android.h" +#else +#include "tensorflow/core/platform/default/protobuf.h" +#endif + +namespace tensorflow { +// Parses a protocol buffer contained in a string in the binary wire format. +// Returns true on success. Note: Unlike protobuf's builtin ParseFromString, +// this function has no size restrictions on the total size of the encoded +// protocol buffer. 
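An editor-added sketch of typical use; GraphDef and the file name are illustrative only:

    // Parse a serialized graph read from disk, without the size restriction
    // that protobuf's builtin ParseFromString would impose.
    string data;
    TF_CHECK_OK(ReadFileToString(Env::Default(), "graph.pb", &data));
    GraphDef graph_def;
    if (!ParseProtoUnlimited(&graph_def, data)) {
      LOG(ERROR) << "malformed GraphDef in graph.pb";
    }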
+bool ParseProtoUnlimited(protobuf::Message* proto, const string& serialized); +bool ParseProtoUnlimited(protobuf::Message* proto, const void* serialized, + size_t size); +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_PROTOBUF_H_ diff --git a/tensorflow/core/platform/protobuf_util.cc b/tensorflow/core/platform/protobuf_util.cc new file mode 100644 index 0000000000..b698d3f0c2 --- /dev/null +++ b/tensorflow/core/platform/protobuf_util.cc @@ -0,0 +1,17 @@ +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +bool ParseProtoUnlimited(protobuf::Message* proto, const string& serialized) { + return ParseProtoUnlimited(proto, serialized.data(), serialized.size()); +} + +bool ParseProtoUnlimited(protobuf::Message* proto, const void* serialized, + size_t size) { + protobuf::io::CodedInputStream coded_stream( + reinterpret_cast(serialized), size); + coded_stream.SetTotalBytesLimit(INT_MAX, INT_MAX); + return proto->ParseFromCodedStream(&coded_stream); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/regexp.h b/tensorflow/core/platform/regexp.h new file mode 100644 index 0000000000..ef46a7aca5 --- /dev/null +++ b/tensorflow/core/platform/regexp.h @@ -0,0 +1,33 @@ +#ifndef TENSORFLOW_PLATFORM_REGEXP_H_ +#define TENSORFLOW_PLATFORM_REGEXP_H_ + +#include "tensorflow/core/platform/port.h" + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) +#include "third_party/re2/re2.h" +namespace tensorflow { +typedef ::StringPiece RegexpStringPiece; +} // namespace tensorflow + +#else + +#include "external/re2/re2/re2.h" +namespace tensorflow { +typedef re2::StringPiece RegexpStringPiece; +} // namespace tensorflow + +#endif + +namespace tensorflow { + +// Conversion to/from the appropriate StringPiece type for using in RE2 +inline RegexpStringPiece ToRegexpStringPiece(tensorflow::StringPiece sp) { + return RegexpStringPiece(sp.data(), sp.size()); +} +inline tensorflow::StringPiece FromRegexpStringPiece(RegexpStringPiece sp) { + return tensorflow::StringPiece(sp.data(), sp.size()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_REGEXP_H_ diff --git a/tensorflow/core/platform/stream_executor_util.h b/tensorflow/core/platform/stream_executor_util.h new file mode 100644 index 0000000000..a6640fb26d --- /dev/null +++ b/tensorflow/core/platform/stream_executor_util.h @@ -0,0 +1,12 @@ +#ifndef TENSORFLOW_PLATFORM_STREAM_EXECUTOR_UTIL_H_ +#define TENSORFLOW_PLATFORM_STREAM_EXECUTOR_UTIL_H_ + +#include "tensorflow/core/platform/port.h" + +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/core/platform/google/stream_executor_util.h" +#else +#include "tensorflow/core/platform/default/stream_executor_util.h" +#endif + +#endif // TENSORFLOW_PLATFORM_STREAM_EXECUTOR_UTIL_H_ diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc new file mode 100644 index 0000000000..a5cbd0ab44 --- /dev/null +++ b/tensorflow/core/platform/tensor_coding.cc @@ -0,0 +1,53 @@ +#include "tensorflow/core/platform/tensor_coding.h" + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace port { + +void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) { + out->assign(src.data(), src.size()); +} + +void EncodeStringList(const string* strings, int64 n, string* out) { + out->clear(); + for (int i = 0; i < n; ++i) { + core::PutVarint32(out, strings[i].size()); + } + for (int i = 0; i < n; ++i) { + out->append(strings[i]); + } 
+} + +bool DecodeStringList(const string& src, string* strings, int64 n) { + std::vector sizes(n); + StringPiece reader(src); + int64 tot = 0; + for (auto& v : sizes) { + if (!core::GetVarint32(&reader, &v)) return false; + tot += v; + } + if (tot != static_cast(reader.size())) { + return false; + } + + string* data = strings; + for (int64 i = 0; i < n; ++i, ++data) { + auto size = sizes[i]; + if (size > reader.size()) { + return false; + } + data->assign(reader.data(), size); + reader.remove_prefix(size); + } + + return true; +} + +void CopyFromArray(string* s, const char* base, size_t bytes) { + s->assign(base, bytes); +} + +} // namespace port +} // namespace tensorflow diff --git a/tensorflow/core/platform/tensor_coding.h b/tensorflow/core/platform/tensor_coding.h new file mode 100644 index 0000000000..6bb9991895 --- /dev/null +++ b/tensorflow/core/platform/tensor_coding.h @@ -0,0 +1,40 @@ +// Helper routines for encoding/decoding tensor contents. +#ifndef TENSORFLOW_PLATFORM_TENSOR_CODING_H_ +#define TENSORFLOW_PLATFORM_TENSOR_CODING_H_ + +#include +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/port.h" + +#ifdef PLATFORM_GOOGLE +#include "tensorflow/core/platform/google/cord_coding.h" +#endif + +namespace tensorflow { +namespace port { + +// Store src contents in *out. If backing memory for src is shared with *out, +// will ref obj during the call and will arrange to unref obj when no +// longer needed. +void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out); + +// Copy contents of src to dst[0,src.size()-1]. +inline void CopyToArray(const string& src, char* dst) { + memcpy(dst, src.data(), src.size()); +} + +// Store encoding of strings[0..n-1] in *out. +void EncodeStringList(const string* strings, int64 n, string* out); + +// Decode n strings from src and store in strings[0..n-1]. +// Returns true if successful, false on parse error. 
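A round-trip sketch of these helpers (the `RoundTripExample` function is illustrative only); per `EncodeStringList`, the encoding is the varint32 lengths of all strings followed by their concatenated bytes:

```c++
#include <string>

#include "tensorflow/core/platform/tensor_coding.h"

void RoundTripExample() {
  std::string in[3] = {"a", "bc", ""};
  std::string encoded;
  tensorflow::port::EncodeStringList(in, 3, &encoded);
  // encoded = varint32 lengths 1, 2, 0 followed by the bytes "abc".

  std::string out[3];
  bool ok = tensorflow::port::DecodeStringList(encoded, out, 3);
  // ok is true and out[0..2] equal in[0..2]; truncated or inconsistent
  // input would make DecodeStringList return false instead.
  (void)ok;
}
```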
+bool DecodeStringList(const string& src, string* strings, int64 n); + +// Assigns base[0..bytes-1] to *s +void CopyFromArray(string* s, const char* base, size_t bytes); + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_TENSOR_CODING_H_ diff --git a/tensorflow/core/platform/test.cc b/tensorflow/core/platform/test.cc new file mode 100644 index 0000000000..21c6905683 --- /dev/null +++ b/tensorflow/core/platform/test.cc @@ -0,0 +1,39 @@ +#include "tensorflow/core/platform/port.h" + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \ + defined(PLATFORM_GOOGLE_ANDROID) +#include "testing/base/public/googletest.h" +#endif + +namespace tensorflow { +namespace testing { + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \ + defined(PLATFORM_GOOGLE_ANDROID) +string TmpDir() { return FLAGS_test_tmpdir; } +int RandomSeed() { return FLAGS_test_random_seed; } +#else +string TmpDir() { + // 'bazel test' sets TEST_TMPDIR + const char* env = getenv("TEST_TMPDIR"); + if (env && env[0] != '\0') { + return env; + } + env = getenv("TMPDIR"); + if (env && env[0] != '\0') { + return env; + } + return "/tmp"; +} +int RandomSeed() { + const char* env = getenv("TEST_RANDOM_SEED"); + int result; + if (env && sscanf(env, "%d", &result) == 1) { + return result; + } + return 301; +} +#endif + +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/core/platform/test.h b/tensorflow/core/platform/test.h new file mode 100644 index 0000000000..ea16fe1442 --- /dev/null +++ b/tensorflow/core/platform/test.h @@ -0,0 +1,17 @@ +#ifndef TENSORFLOW_PLATFORM_TEST_H_ +#define TENSORFLOW_PLATFORM_TEST_H_ + +namespace tensorflow { +namespace testing { + +// Return a temporary directory suitable for temporary testing files. +string TmpDir(); + +// Return a random number generator seed to use in randomized tests. +// Returns the same value for the lifetime of the process. +int RandomSeed(); + +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_TEST_H_ diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h new file mode 100644 index 0000000000..8c8a92a519 --- /dev/null +++ b/tensorflow/core/platform/test_benchmark.h @@ -0,0 +1,58 @@ +// Simple benchmarking facility. 
+#ifndef TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_ +#define TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_ + +#include "tensorflow/core/platform/port.h" + +#if defined(PLATFORM_GOOGLE) +#include "testing/base/public/benchmark.h" + +#else +#define BENCHMARK(n) \ + static ::tensorflow::testing::Benchmark* TF_BENCHMARK_CONCAT( \ + __benchmark_, n, __LINE__) TF_ATTRIBUTE_UNUSED = \ + (new ::tensorflow::testing::Benchmark(#n, (n))) +#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c) +#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c + +#endif // PLATFORM_GOOGLE + +namespace tensorflow { +namespace testing { + +#if defined(PLATFORM_GOOGLE) +using ::testing::Benchmark; +#else +class Benchmark { + public: + Benchmark(const char* name, void (*fn)(int)); + Benchmark(const char* name, void (*fn)(int, int)); + + Benchmark* Arg(int x); + Benchmark* Range(int lo, int hi); + static void Run(const char* pattern); + + private: + string name_; + int num_args_; + std::vector args_; + void (*fn0_)(int) = nullptr; + void (*fn1_)(int, int) = nullptr; + + void Register(); + void Run(int arg, int* run_count, double* run_seconds); +}; +#endif + +void RunBenchmarks(); +void SetLabel(const std::string& label); +void BytesProcessed(int64); +void ItemsProcessed(int64); +void StartTiming(); +void StopTiming(); +void UseRealTime(); + +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_ diff --git a/tensorflow/core/platform/test_main.cc b/tensorflow/core/platform/test_main.cc new file mode 100644 index 0000000000..11230c3f7b --- /dev/null +++ b/tensorflow/core/platform/test_main.cc @@ -0,0 +1,31 @@ +// A program with a main that is suitable for unittests, including those +// that also define microbenchmarks. Based on whether the user specified +// the --benchmark_filter flag which specifies which benchmarks to run, +// we will either run benchmarks or run the gtest tests in the program. 
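As an illustrative sketch (not part of the patch), a microbenchmark is declared with the `BENCHMARK` macro from `test_benchmark.h`; a test binary linked against the `main()` below runs it when invoked with a `--benchmarks=` pattern and falls back to the gtest tests otherwise:

```c++
#include <vector>

#include "tensorflow/core/platform/test_benchmark.h"

// The registered function is called with the iteration count chosen by the
// benchmark framework.
static void BM_VectorPushBack(int iters) {
  while (iters-- > 0) {
    std::vector<int> v;
    v.push_back(42);
  }
}
BENCHMARK(BM_VectorPushBack);
```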
+ +#include + +#include "tensorflow/core/platform/port.h" + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \ + defined(PLATFORM_GOOGLE_ANDROID) +// main() is supplied by gunit_main +#else +#include "gtest/gtest.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/test_benchmark.h" + +GTEST_API_ int main(int argc, char** argv) { + std::cout << "Running main() from test_main.cc\n"; + + testing::InitGoogleTest(&argc, argv); + for (int i = 1; i < argc; i++) { + if (tensorflow::StringPiece(argv[i]).starts_with("--benchmarks=")) { + const char* pattern = argv[i] + strlen("--benchmarks="); + tensorflow::testing::Benchmark::Run(pattern); + return 0; + } + } + return RUN_ALL_TESTS(); +} +#endif diff --git a/tensorflow/core/platform/thread_annotations.h b/tensorflow/core/platform/thread_annotations.h new file mode 100644 index 0000000000..cb8040eed6 --- /dev/null +++ b/tensorflow/core/platform/thread_annotations.h @@ -0,0 +1,14 @@ +#ifndef TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_ +#define TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_ + +#include "tensorflow/core/platform/port.h" + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) +#include "base/thread_annotations.h" +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) +#include "tensorflow/core/platform/default/thread_annotations.h" +#else +#error Define the appropriate PLATFORM_ macro for this platform +#endif + +#endif // TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_ diff --git a/tensorflow/core/platform/tracing.cc b/tensorflow/core/platform/tracing.cc new file mode 100644 index 0000000000..a4cb92dee4 --- /dev/null +++ b/tensorflow/core/platform/tracing.cc @@ -0,0 +1,135 @@ +#include "tensorflow/core/platform/tracing.h" + +#include +#include +#include +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +StepStatsCollector::StepStatsCollector(StepStats* ss) : step_stats_(ss) {} + +void StepStatsCollector::Save(const string& device, NodeExecStats* nt) { + VLOG(1) << "Save dev " << device << " nt " << nt; + { + mutex_lock l(mu_); + DeviceStepStats* dss = nullptr; + // Slow linear scan, but it should only be called + // by a Worker in a context with < ~10 devices. + // TODO(tucker): consider adding a std::unordered_map. + for (auto& ds : *step_stats_->mutable_dev_stats()) { + if (ds.device() == device) { + dss = &ds; + break; + } + } + if (dss == nullptr) { + dss = step_stats_->add_dev_stats(); + dss->set_device(device); + } + nt->Swap(dss->add_node_stats()); + } + delete nt; +} + +void StepStatsCollector::Swap(StepStats* ss) { + mutex_lock l(mu_); + CHECK(step_stats_); + ss->Swap(step_stats_); +} + +namespace port { + +int32 Tracing::category_id_[kEventCategoryMax]; +uint64 Tracing::event_mask_ = 0; +std::map* Tracing::name_map_ = new std::map; + +// This needs to be kept in sync with the EventCategory enumeration. +const char* Tracing::EventCategoryString(EventCategory category) { + switch (category) { + case EventCategory::kScheduleClosure: + return "ScheduleClosure"; + case EventCategory::kRunClosure: + return "RunClosure"; + case EventCategory::kCompute: + return "Compute"; + case EventCategory::kEventCategoryMax: + return "EventCategoryMax"; + } + return "Unknown"; +} + +// This function allows the user to specify arbitrary subsets of the +// supported Threadscape events and activities. 
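For illustration only, a few mask strings of the shape this parser accepts, assuming categories are registered under the names returned by `EventCategoryString` (an assumption; registration happens elsewhere):

```c++
// Illustrative values for the event-mask expression parsed below.
const char* const kTraceEverything = "ALL";
const char* const kTraceComputeOnly = "Compute";
const char* const kTraceAllButScheduling = "ALL,!ScheduleClosure";
// A leading '!' clears a category; entries are applied left to right.
```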
+bool Tracing::ParseEventMask(const char* flagname, const string& value) { + VLOG(1) << flagname << " set to " << value; + int64 new_mask = 0; + std::vector events = + str_util::Split(value, ',', str_util::SkipEmpty()); + for (string name : events) { + bool clear = false; + int64 mask = 0; + if (name[0] == '!') { + // invert the sense of the flag + clear = true; + name = name.substr(1); + } + if (name == "ALL") { + mask = ~0; + } else { + auto it = name_map_->find(name); + int32 id; + if (it == name_map_->end()) { + id = -1; + } else { + id = it->second; + } + if (id < 0) { + LOG(ERROR) << "Can't parse event mask name " << name; + return false; + } + mask = 1 << id; + } + if (clear) { + new_mask &= ~mask; + } else { + new_mask |= mask; + } + } + // parsing was successful; set the permanent event mask + event_mask_ = new_mask; + return true; +} + +static std::atomic tracing_engine; + +void Tracing::RegisterEngine(Engine* e) { + tracing_engine.store(e, std::memory_order_release); +} + +static Tracing::Engine* engine() { + return tracing_engine.load(std::memory_order_acquire); +} + +Tracing::Engine::~Engine() {} +Tracing::Engine::Annotation::~Annotation() {} +Tracing::Engine::Tracer::~Tracer() {} + +Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name) { + auto e = engine(); + if (e) { + annotation_.reset(e->PushAnnotation(name)); + } +} + +Tracing::TraceMe::TraceMe(StringPiece name) { + auto e = engine(); + if (e) { + tracer_.reset(e->StartTracing(name)); + } +} + +} // namespace port +} // namespace tensorflow diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h new file mode 100644 index 0000000000..2b53a64cf1 --- /dev/null +++ b/tensorflow/core/platform/tracing.h @@ -0,0 +1,205 @@ +#ifndef TENSORFLOW_PLATFORM_TRACING_H_ +#define TENSORFLOW_PLATFORM_TRACING_H_ + +// Tracing interface + +#include +#include + +#include "tensorflow/core/platform/port.h" // Must be first +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +class NodeExecStats; +class StepStats; + +class StepStatsCollector { + public: + explicit StepStatsCollector(StepStats* ss); + + void Save(const string& device, NodeExecStats* nt); + + void Swap(StepStats* ss); + + private: + friend class StepStatsMgr; + mutex mu_; + StepStats* step_stats_ GUARDED_BY(mu_); +}; + +namespace port { + +class Tracing { + public: + // This enumeration contains the identifiers of all TensorFlow + // threadscape events and code regions. Threadscape assigns its + // own identiers at runtime when we register our events and we + // cannot know in advance what IDs it will choose. The "RecordEvent" + // method and "ScopedActivity" use these event IDs for consistency + // and remap them to threadscape IDs at runtime. This enum is limited + // to 64 values since we use a bitmask to configure which events are + // enabled. It must also be kept in step with the code in + // "Tracing::EventCategoryString". + enum EventCategory { + kScheduleClosure = 0, + kRunClosure = 1, + kCompute = 2, + kEventCategoryMax = 3 // sentinel - keep last + }; + // Note: We currently only support up to 64 categories. + static_assert(kEventCategoryMax <= 64, "only support up to 64 events"); + + // Called by main programs to initialize tracing facilities + static void Initialize(); + + // Return the pathname of the directory where we are writing log files. 
+  static const char* LogDir();
+
+  // Returns a non-zero identifier which can be used to correlate
+  // related events.
+  static inline uint64 UniqueId();
+
+  // Returns true if a trace is in progress. Can be used to reduce tracing
+  // overheads in fast-path code.
+  static inline bool IsActive();
+
+  // Associate name with the current thread.
+  static void RegisterCurrentThread(const char* name);
+
+  // Posts an event with the supplied category and arg.
+  static void RecordEvent(EventCategory category, uint64 arg);
+
+  // Traces a region of code. Posts a tracing "EnterCodeRegion" event
+  // when created and an "ExitCodeRegion" event when destroyed.
+  class ScopedActivity {
+   public:
+    explicit ScopedActivity(EventCategory category, uint64 arg);
+    ~ScopedActivity();
+
+   private:
+    const bool enabled_;
+    const int32 region_id_;
+
+    TF_DISALLOW_COPY_AND_ASSIGN(ScopedActivity);
+  };
+
+  // Trace collection engine can be registered with this module.
+  // If no engine is registered, ScopedAnnotation and TraceMe are no-ops.
+  class Engine;
+  static void RegisterEngine(Engine*);
+
+  // Forward declaration of the GPU utility classes.
+  class ScopedAnnotation;
+  class TraceMe;
+
+ private:
+  friend class TracingTest;
+
+  static void RegisterEvent(EventCategory id, const char* name);
+  static const char* EventCategoryString(EventCategory category);
+
+  //
+  // Parses event mask expressions in 'value' of the form:
+  //   expr ::= <term> (,<term>)*
+  //   term ::= <event> | "!" <event>
+  //   event ::= "ALL" | <wait_event> | <other_event>
+  //   wait_event ::= "ENewSession" | "ECloseSession" | ...
+  //   other_event ::= "Send" | "Wait" | ...
+  // ALL denotes all events, <event> turns on tracing for this event, and
+  // !<event> turns off tracing for this event.
+  // If the expression can be parsed correctly it returns true and sets
+  // the event_mask_. Otherwise it returns false and the event_mask_ is left
+  // unchanged.
+  static bool ParseEventMask(const char* flagname, const string& value);
+
+  // Bit mask of enabled trace categories.
+  static uint64 event_mask_;
+
+  // Records the mappings between Threadscape IDs and the "EventCategory" enum.
+  static int32 category_id_[kEventCategoryMax];
+  static std::map<string, int32>* name_map_;
+};
+
+// Trace collection engine that actually implements collection.
+class Tracing::Engine {
+ public:
+  Engine() {}
+  virtual ~Engine();
+
+  // Represents an active annotation.
+  class Annotation {
+   public:
+    Annotation() {}
+    virtual ~Annotation();
+  };
+
+  // Represents an active trace.
+  class Tracer {
+   public:
+    Tracer() {}
+    virtual ~Tracer();
+  };
+
+ private:
+  friend class ScopedAnnotation;
+  friend class TraceMe;
+
+  // Register the specified name as an annotation on the current thread.
+  // Caller should delete the result to remove the annotation.
+  // Annotations from the same thread are destroyed in a LIFO manner.
+  // May return nullptr if annotations are not supported.
+  virtual Annotation* PushAnnotation(StringPiece name) = 0;
+
+  // Start tracing under the specified label. Caller should delete the
+  // result to stop tracing.
+  // May return nullptr if tracing is not supported.
+  virtual Tracer* StartTracing(StringPiece label) = 0;
+};
+
+// This class permits a user to apply annotation on kernels and memcpys
+// when launching them. While an annotation is in scope, all activities
+// within that scope get their names replaced by the annotation. The kernel
+// name replacement is done when constructing the protobuf for sending out to
+// a client (e.g., the stubby requestor) for both API and Activity records.
+// +// Ownership: The creator of ScopedAnnotation assumes ownership of the object. +// +// Usage: { +// ScopedAnnotation annotation("first set of kernels"); +// Kernel1<<>>; +// LaunchKernel2(); // Which eventually launches a cuda kernel. +// } +// In the above scenario, the GPUProf UI would show 2 kernels with the name +// "first set of kernels" executing -- they will appear as the same kernel. +class Tracing::ScopedAnnotation { + public: + explicit ScopedAnnotation(StringPiece name); + + private: + std::unique_ptr annotation_; +}; + +// TODO(opensource): clean up the scoped classes for GPU tracing. +// This class permits user-specified (CPU) tracing activities. A trace +// activity is started when an object of this class is created and stopped +// when the object is destroyed. +class Tracing::TraceMe { + public: + explicit TraceMe(StringPiece name); + + private: + std::unique_ptr tracer_; +}; + +} // namespace port +} // namespace tensorflow + +#if defined(PLATFORM_GOOGLE) && !defined(ANDROID) && !defined(__ANDROID__) +#include "tensorflow/core/platform/google/tracing_impl.h" +#else +#include "tensorflow/core/platform/default/tracing_impl.h" +#endif + +#endif // TENSORFLOW_PLATFORM_TRACING_H_ diff --git a/tensorflow/core/public/README.md b/tensorflow/core/public/README.md new file mode 100644 index 0000000000..b1afff87de --- /dev/null +++ b/tensorflow/core/public/README.md @@ -0,0 +1,90 @@ +# TensorFlow + +TensorFlow is a computational dataflow graph library. + +## Getting started + + +### Python API example +The following is an example python code to do a simple matrix multiply +of two constants and get the result from a locally-running TensorFlow +process. + +First, bring in the following dependency: + +//third_party/tensorflow/core/public:tensorflow_py + +to get the python TensorFlow API. If you intend to run TensorFlow within +the same process, link in the following to the same binary: + +//third_party/tensorflow/core/public:tensorflow_std_ops + +to get the standard set of op implementations. Then: + +```python +import tensorflow as tf + +with tf.Session("local"): + input1 = tf.Constant(1.0, shape=[1, 1], name="input1") + input2 = tf.Constant(2.0, shape=[1, 1], name="input2") + output = tf.MatMul(input1, input2) + + # Run graph and fetch the output + result = output.eval() + print result +``` + +### C++ API Example + +If you are running TensorFlow locally, link your binary with + +//third_party/tensorflow/core/public:tensorflow_local + +and link in the operation implementations you want to supported, e.g., + +//third_party/tensorflow/core/public:tensorflow_std_ops + +An example program to take a GraphDef and run it using TensorFlow +using the C++ Session API: + +```c++ +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/tensor.h" + +int main(int argc, char** argv) { + // Construct your graph. + tensorflow::GraphDef graph = ...; + + // Create a Session running TensorFlow locally in process. + std::unique_ptr session(tensorflow::NewSession({})); + + // Initialize the session with the graph. + tensorflow::Status s = session->Create(graph); + if (!s.ok()) { ... } + + // Specify the 'feeds' of your network if needed. + std::vector> inputs; + + // Run the session, asking for the first output of "my_output". + std::vector outputs; + s = session->Run(inputs, {"my_output:0"}, {}, &outputs); + if (!s.ok()) { ... 
} + + // Do something with your outputs + auto output_vector = outputs[0].vec(); + if (output_vector(0) > 0.5) { ... } + + // Close the session. + session->Close(); + + return 0; +} +``` + +For a more fully-featured C++ example, see +`tensorflow/cc/tutorials/example_trainer.cc` diff --git a/tensorflow/core/public/env.h b/tensorflow/core/public/env.h new file mode 100644 index 0000000000..4024525859 --- /dev/null +++ b/tensorflow/core/public/env.h @@ -0,0 +1,273 @@ +#ifndef TENSORFLOW_PUBLIC_ENV_H_ +#define TENSORFLOW_PUBLIC_ENV_H_ + +#include +#include +#include +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +class RandomAccessFile; +class Thread; +class ThreadOptions; +class WritableFile; + +/// \brief An interface used by the tensorflow implementation to +/// access operating system functionality like the filesystem etc. +/// +/// Callers may wish to provide a custom Env object to get fine grain +/// control. +/// +/// All Env implementations are safe for concurrent access from +/// multiple threads without any external synchronization. +class Env { + public: + Env() {} + virtual ~Env(); + + /// \brief Returns a default environment suitable for the current operating + /// system. + /// + /// Sophisticated users may wish to provide their own Env + /// implementation instead of relying on this default environment. + /// + /// The result of Default() belongs to this library and must never be deleted. + static Env* Default(); + + /// \brief Creates a brand new random access read-only file with the + /// specified name. + + /// On success, stores a pointer to the new file in + /// *result and returns OK. On failure stores NULL in *result and + /// returns non-OK. If the file does not exist, returns a non-OK + /// status. + /// + /// The returned file may be concurrently accessed by multiple threads. + virtual Status NewRandomAccessFile(const string& fname, + RandomAccessFile** result) = 0; + + /// \brief Creates an object that writes to a new file with the specified + /// name. + /// + /// Deletes any existing file with the same name and creates a + /// new file. On success, stores a pointer to the new file in + /// *result and returns OK. On failure stores NULL in *result and + /// returns non-OK. + /// + /// The returned file will only be accessed by one thread at a time. + virtual Status NewWritableFile(const string& fname, + WritableFile** result) = 0; + + /// \brief Creates an object that either appends to an existing file, or + /// writes to a new file (if the file does not exist to begin with). + /// + /// On success, stores a pointer to the new file in *result and + /// returns OK. On failure stores NULL in *result and returns + /// non-OK. + /// + /// The returned file will only be accessed by one thread at a time. + virtual Status NewAppendableFile(const string& fname, + WritableFile** result) = 0; + + /// Returns true iff the named file exists. + virtual bool FileExists(const string& fname) = 0; + + /// \brief Stores in *result the names of the children of the specified + /// directory. The names are relative to "dir". + /// + /// Original contents of *results are dropped. + virtual Status GetChildren(const string& dir, + std::vector* result) = 0; + + /// Deletes the named file. + virtual Status DeleteFile(const string& fname) = 0; + + /// Creates the specified directory. 
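A small sketch of how these calls compose (the helper and directory path are made up for illustration):

```c++
#include <string>

#include "tensorflow/core/public/env.h"

tensorflow::Status EnsureLogDir(const std::string& dir) {
  tensorflow::Env* env = tensorflow::Env::Default();
  if (env->FileExists(dir)) {
    return tensorflow::Status::OK();
  }
  return env->CreateDir(dir);
}
```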
+ virtual Status CreateDir(const string& dirname) = 0; + + /// Deletes the specified directory. + virtual Status DeleteDir(const string& dirname) = 0; + + /// Stores the size of fname in *file_size. + virtual Status GetFileSize(const string& fname, uint64* file_size) = 0; + + /// \brief Renames file src to target. If target already exists, it will be + /// replaced. + virtual Status RenameFile(const string& src, const string& target) = 0; + + // TODO(jeff,sanjay): Add back thread/thread-pool support if needed. + // TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or + // provide a routine to get the absolute time. + + /// \brief Returns the number of micro-seconds since some fixed point in + /// time. Only useful for computing deltas of time. + virtual uint64 NowMicros() = 0; + + /// Sleeps/delays the thread for the prescribed number of micro-seconds. + virtual void SleepForMicroseconds(int micros) = 0; + + /// \brief Returns a new thread that is running fn() and is identified + /// (for debugging/performance-analysis) by "name". + /// + /// Caller takes ownership of the result and must delete it eventually + /// (the deletion will block until fn() stops running). + virtual Thread* StartThread(const ThreadOptions& thread_options, + const string& name, + std::function fn) TF_MUST_USE_RESULT = 0; + + private: + /// No copying allowed + Env(const Env&); + void operator=(const Env&); +}; + +/// A file abstraction for randomly reading the contents of a file. +class RandomAccessFile { + public: + RandomAccessFile() {} + virtual ~RandomAccessFile(); + + /// \brief Reads up to "n" bytes from the file starting at "offset". + /// + /// "scratch[0..n-1]" may be written by this routine. Sets "*result" + /// to the data that was read (including if fewer than "n" bytes were + /// successfully read). May set "*result" to point at data in + /// "scratch[0..n-1]", so "scratch[0..n-1]" must be live when + /// "*result" is used. + /// + /// On OK returned status: "n" bytes have been stored in "*result". + /// On non-OK returned status: [0..n] bytes have been stored in "*result". + /// + /// Returns OUT_OF_RANGE if fewer than n bytes were stored in "*result" + /// because of EOF. + /// + /// Safe for concurrent use by multiple threads. + virtual Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const = 0; + + private: + /// No copying allowed + RandomAccessFile(const RandomAccessFile&); + void operator=(const RandomAccessFile&); +}; + +/// \brief A file abstraction for sequential writing. +/// +/// The implementation must provide buffering since callers may append +/// small fragments at a time to the file. +class WritableFile { + public: + WritableFile() {} + virtual ~WritableFile(); + + virtual Status Append(const StringPiece& data) = 0; + virtual Status Close() = 0; + virtual Status Flush() = 0; + virtual Status Sync() = 0; + + private: + /// No copying allowed + WritableFile(const WritableFile&); + void operator=(const WritableFile&); +}; + +/// \brief An implementation of Env that forwards all calls to another Env. +/// +/// May be useful to clients who wish to override just part of the +/// functionality of another Env. 
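For instance, a test-only Env that overrides just the clock and forwards everything else (a sketch; `FakeClockEnv` is not part of this patch):

```c++
#include "tensorflow/core/public/env.h"

class FakeClockEnv : public tensorflow::EnvWrapper {
 public:
  explicit FakeClockEnv(tensorflow::Env* base) : tensorflow::EnvWrapper(base) {}

  // Only the clock is overridden; every other call goes to the wrapped Env.
  tensorflow::uint64 NowMicros() override { return now_micros_; }

  void AdvanceMicros(tensorflow::uint64 delta) { now_micros_ += delta; }

 private:
  tensorflow::uint64 now_micros_ = 0;
};
```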
+class EnvWrapper : public Env { + public: + /// Initializes an EnvWrapper that delegates all calls to *t + explicit EnvWrapper(Env* t) : target_(t) {} + virtual ~EnvWrapper(); + + /// Returns the target to which this Env forwards all calls + Env* target() const { return target_; } + + // The following text is boilerplate that forwards all methods to target() + Status NewRandomAccessFile(const string& f, + RandomAccessFile** r) override { + return target_->NewRandomAccessFile(f, r); + } + Status NewWritableFile(const string& f, WritableFile** r) override { + return target_->NewWritableFile(f, r); + } + Status NewAppendableFile(const string& f, WritableFile** r) override { + return target_->NewAppendableFile(f, r); + } + bool FileExists(const string& f) override { return target_->FileExists(f); } + Status GetChildren(const string& dir, std::vector* r) override { + return target_->GetChildren(dir, r); + } + Status DeleteFile(const string& f) override { + return target_->DeleteFile(f); + } + Status CreateDir(const string& d) override { + return target_->CreateDir(d); + } + Status DeleteDir(const string& d) override { + return target_->DeleteDir(d); + } + Status GetFileSize(const string& f, uint64* s) override { + return target_->GetFileSize(f, s); + } + Status RenameFile(const string& s, const string& t) override { + return target_->RenameFile(s, t); + } + uint64 NowMicros() override { return target_->NowMicros(); } + void SleepForMicroseconds(int micros) override { + target_->SleepForMicroseconds(micros); + } + Thread* StartThread(const ThreadOptions& thread_options, const string& name, + std::function fn) override { + return target_->StartThread(thread_options, name, fn); + } + + private: + Env* target_; +}; + +class Thread { + public: + Thread() {} + + /// Blocks until the thread of control stops running. + virtual ~Thread(); + + private: + /// No copying allowed + Thread(const Thread&); + void operator=(const Thread&); +}; + +/// \brief Options to configure a Thread. +/// +/// Note that the options are all hints, and the +/// underlying implementation may choose to ignore it. +struct ThreadOptions { + /// Thread stack size to use (in bytes). + size_t stack_size = 0; // 0: use system default value + /// Guard area size to use near thread stacks to use (in bytes) + size_t guard_size = 0; // 0: use system default value +}; + +/// A utility routine: reads contents of named file into *data +Status ReadFileToString(Env* env, const string& fname, string* data); + +/// A utility routine: write contents of "data" to file named "fname" +/// (overwriting existing contents, if any). +Status WriteStringToFile(Env* env, const string& fname, + const StringPiece& data); + +/// Reads contents of named file and parse as binary encoded proto data +/// and store into *proto. 
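For example (a sketch; the wrapper is hypothetical), loading a binary-serialized `GraphDef` from disk:

```c++
#include <string>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/public/env.h"

tensorflow::Status LoadGraph(const std::string& path,
                             tensorflow::GraphDef* graph_def) {
  // GraphDef is a protobuf message, so it satisfies the MessageLite interface.
  return tensorflow::ReadBinaryProto(tensorflow::Env::Default(), path,
                                     graph_def);
}
```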
+Status ReadBinaryProto(Env* env, const string& fname, + ::tensorflow::protobuf::MessageLite* proto); + +} // namespace tensorflow + +#endif // TENSORFLOW_PUBLIC_ENV_H_ diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h new file mode 100644 index 0000000000..a33d5ee6ae --- /dev/null +++ b/tensorflow/core/public/session.h @@ -0,0 +1,125 @@ +#ifndef TENSORFLOW_PUBLIC_SESSION_H_ +#define TENSORFLOW_PUBLIC_SESSION_H_ + +#include +#include + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +/// \brief A Session instance lets a caller drive a TensorFlow graph +/// computation. +/// +/// When a Session is created with a given target, a new Session object +/// is bound to the universe of resources specified by that target. +/// Those resources are available to this session to perform +/// computation described in the GraphDef. After extending the session +/// with a graph, the caller uses the Run() API to perform the +/// computation and potentially fetch outputs as Tensors. +/// +/// Example: +/// +/// tensorflow::GraphDef graph; +/// // ... Create or load graph into 'graph'. +/// +/// // This example uses the default options which connects +/// // to a local runtime. +/// tensorflow::SessionOptions options; +/// std::unique_ptr +/// session(tensorflow::NewSession(options)); +/// +/// // Create the session with this graph. +/// tensorflow::Status s = session->Create(graph); +/// if (!s.ok()) { ... } +/// +/// // Run the graph and fetch the first output of the "output" +/// // operation, and also run to but do not return anything +/// // for the "update_state" operation. +/// std::vector outputs; +/// s = session->Run({}, {"output:0"}, {"update_state"}, &outputs); +/// if (!s.ok()) { ... } +/// +/// // Map the output as a flattened float tensor, and do something +/// // with it. +/// auto output_tensor = outputs[0].flat(); +/// if (output_tensor(0) > 0.5) { ... } +/// +/// // Close the session to release the resources associated with +/// // this session. +/// session->Close() +/// +/// A Session allows concurrent calls to Run(), though a Session must +/// be created / extended by a single thread. +/// +/// Only one thread must call Close(), and Close() must only be called +/// after all other calls to Run() have returned. +class Session { + public: + /// \brief Create the graph to be used for the session. + /// + /// Returns an error if this session has already been created with a + /// graph. To re-use the session with a different graph, the caller + /// must Close() the session first. + virtual Status Create(const GraphDef& graph) = 0; + + /// \brief Adds operations to the graph that is already registered with the + /// Session. + /// + /// The names of new operations in "graph" must not exist in the + /// graph that is already registered. + virtual Status Extend(const GraphDef& graph) = 0; + + /// \brief Runs the graph with the provided input tensors and fills + /// 'outputs' for the endpoints specified in 'output_tensor_names'. + /// Runs to but does not return Tensors for the nodes in + /// 'target_node_names'. + /// + /// The order of tensors in 'outputs' will match the order provided + /// by 'output_tensor_names'. 
+ /// + /// If Run returns OK(), then outputs->size() will be equal to + /// output_tensor_names.size(). If Run does not return OK(), the + /// state of outputs is undefined. + /// + /// REQUIRES: The name of each Tensor of the input or output must + /// match a "Tensor endpoint" in the GraphDef passed to Create(). + /// + /// REQUIRES: outputs is not nullptr if output_tensor_names is non-empty. + virtual Status Run(const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) = 0; + + /// \brief Closes this session. + /// + /// Closing a session releases the resources used by this session + /// on the TensorFlow runtime (specified during session creation by + /// the 'SessionOptions::target' field). + virtual Status Close() = 0; + + virtual ~Session() {} +}; + +/// \brief Create a new session with the given options. +/// +/// If a new session object could not be created, this function will +/// return nullptr. +Session* NewSession(const SessionOptions& options); + +/// \brief Create a new session with the given options. +/// +/// If session creation succeeds, the new Session will be stored in +/// *out_session, the caller will take ownership of the returned +/// *out_session, and this function will return OK(). Otherwise, this +/// function will return an error status. +Status NewSession(const SessionOptions& options, Session** out_session); + +} // end namespace tensorflow + +#endif // TENSORFLOW_PUBLIC_SESSION_H_ diff --git a/tensorflow/core/public/session_options.h b/tensorflow/core/public/session_options.h new file mode 100644 index 0000000000..11d52426ac --- /dev/null +++ b/tensorflow/core/public/session_options.h @@ -0,0 +1,50 @@ +#ifndef TENSORFLOW_PUBLIC_SESSION_OPTIONS_H_ +#define TENSORFLOW_PUBLIC_SESSION_OPTIONS_H_ + +#include +#include "tensorflow/core/framework/config.pb.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class Env; + +/// Configuration information for a Session. +struct SessionOptions { + /// The environment to use. + Env* env; + + /// \brief The TensorFlow runtime to connect to. + /// + /// If 'target' is empty or unspecified, the local TensorFlow runtime + /// implementation will be used. Otherwise, the TensorFlow engine + /// defined by 'target' will be used to perform all computations. + /// + /// "target" can be either a single entry or a comma separated list + /// of entries. Each entry is a resolvable address of the + /// following format: + /// local + /// ip:port + /// host:port + /// ... other system-specific formats to identify tasks and jobs ... + /// + /// NOTE: at the moment 'local' maps to an in-process service-based + /// runtime. + /// + /// Upon creation, a single session affines itself to one of the + /// remote processes, with possible load balancing choices when the + /// "target" resolves to a list of possible processes. + /// + /// If the session disconnects from the remote process during its + /// lifetime, session calls may fail immediately. + string target; + + /// Configuration options. 
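A construction sketch (illustrative; the "local" target mirrors the comment above, and error handling is elided):

```c++
#include <memory>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

std::unique_ptr<tensorflow::Session> MakeLocalSession() {
  tensorflow::SessionOptions options;
  options.target = "local";  // in-process runtime, per the comment above
  return std::unique_ptr<tensorflow::Session>(
      tensorflow::NewSession(options));  // may wrap nullptr on failure
}
```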
+ ConfigProto config; + + SessionOptions(); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_PUBLIC_SESSION_OPTIONS_H_ diff --git a/tensorflow/core/public/status.h b/tensorflow/core/public/status.h new file mode 100644 index 0000000000..d0405b8876 --- /dev/null +++ b/tensorflow/core/public/status.h @@ -0,0 +1,96 @@ +#ifndef TENSORFLOW_PUBLIC_STATUS_H_ +#define TENSORFLOW_PUBLIC_STATUS_H_ + +#include +#include +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +class Status { + public: + /// Create a success status. + Status() : state_(NULL) {} + ~Status(); + + /// \brief Create a status with the specified error code and msg as a + /// human-readable string containing more detailed information. + Status(tensorflow::error::Code code, tensorflow::StringPiece msg); + + /// Copy the specified status. + Status(const Status& s); + void operator=(const Status& s); + + static Status OK() { return Status(); } + + /// Returns true iff the status indicates success. + bool ok() const { return (state_ == NULL); } + + tensorflow::error::Code code() const { + return ok() ? tensorflow::error::OK : state_->code; + } + + const string& error_message() const { + return ok() ? empty_string() : state_->msg; + } + + bool operator==(const Status& x) const; + bool operator!=(const Status& x) const; + + /// \brief If "ok()", stores "new_status" into *this. If "!ok()", preserves + /// the current status, but may augment with additional information + /// about "new_status". + /// + /// Convenient way of keeping track of the first error encountered. + /// Instead of: + /// if (overall_status.ok()) overall_status = new_status + /// Use: + /// overall_status.Update(new_status); + void Update(const Status& new_status); + + /// \brief Return a string representation of this status suitable for + /// printing. Returns the string "OK" for success. + string ToString() const; + + private: + static const string& empty_string(); + struct State { + tensorflow::error::Code code; + string msg; + }; + /// OK status has a NULL state_. Otherwise, state_ points to + /// a State structure containing the error code and message(s) + State* state_; + + void SlowCopyFrom(const State* src); +}; + +inline Status::Status(const Status& s) + : state_((s.state_ == NULL) ? NULL : new State(*s.state_)) {} + +inline void Status::operator=(const Status& s) { + /// The following condition catches both aliasing (when this == &s), + /// and the common case where both s and *this are ok. 
+ if (state_ != s.state_) { + SlowCopyFrom(s.state_); + } +} + +inline bool Status::operator==(const Status& x) const { + return (this->state_ == x.state_) || (ToString() == x.ToString()); +} + +inline bool Status::operator!=(const Status& x) const { return !(*this == x); } + +std::ostream& operator<<(std::ostream& os, const Status& x); + +typedef std::function StatusCallback; + +#define TF_CHECK_OK(val) CHECK_EQ(::tensorflow::Status::OK(), (val)) +#define TF_QCHECK_OK(val) QCHECK_EQ(::tensorflow::Status::OK(), (val)) + +} // namespace tensorflow + +#endif // TENSORFLOW_PUBLIC_STATUS_H_ diff --git a/tensorflow/core/public/tensor.h b/tensorflow/core/public/tensor.h new file mode 100644 index 0000000000..6c6ff0f58a --- /dev/null +++ b/tensorflow/core/public/tensor.h @@ -0,0 +1,472 @@ +#ifndef TENSORFLOW_PUBLIC_TENSOR_H_ +#define TENSORFLOW_PUBLIC_TENSOR_H_ + +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_description.pb.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +class TensorBuffer; // Forward declaration. +class TensorCApi; + +/// Represents an n-dimensional array of values. +class Tensor { + public: + /// Default Tensor constructor. Creates a 1-dimension, 0-element float tensor. + Tensor(); + + /// \brief Creates a Tensor of the given datatype and shape. + /// + /// The underlying buffer is allocated using a CPUAllocator. + Tensor(DataType type, const TensorShape& shape); + + /// \brief Creates a tensor with the input datatype and shape, using the + /// allocator 'a' to allocate the underlying buffer. + /// + /// 'a' must outlive the lifetime of this Tensor. + Tensor(Allocator* a, DataType type, const TensorShape& shape); + + /// Creates an uninitialized Tensor of the given data type. + explicit Tensor(DataType type); + + Tensor(const Tensor& other); /// Copy constructor. + + ~Tensor(); + + /// Returns the data type. + DataType dtype() const { return type_; } + + /// Returns the shape of the tensor. + const TensorShape& shape() const { return shape_; } + + /// \brief Convenience accessor for the tensor shape. + /// + /// For all shape accessors, see comments for relevant methods of + /// TensorShape in tensor_shape.h. + int dims() const { return shape().dims(); } + + /// Convenience accessor for the tensor shape. + int64 dim_size(int d) const { return shape().dim_size(d); } + + /// Convenience accessor for the tensor shape. + int64 NumElements() const { return shape().num_elements(); } + + bool IsSameSize(const Tensor& b) const { + return shape().IsSameSize(b.shape()); + } + + /// Has this Tensor been initialized? + bool IsInitialized() const; + + /// Returns the estimated memory usage of this tensor. + size_t TotalBytes() const; + + /// Assign operator. This tensor shares other's underlying storage. 
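To make the sharing semantics concrete, a small sketch (shapes and values are arbitrary):

```c++
#include "tensorflow/core/public/tensor.h"

void SharingExample() {
  tensorflow::Tensor a(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 3}));
  a.flat<float>().setZero();

  tensorflow::Tensor b;
  b = a;  // b now refers to the same underlying buffer as a.

  b.flat<float>()(0) = 1.0f;
  // a.flat<float>()(0) is also 1.0f, since no data was copied.
}
```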
+ Tensor& operator=(const Tensor& other) { + CopyFromInternal(other, other.shape()); + return *this; + } + + /// \brief Copy the other tensor into this tensor and reshape it. + /// + /// This tensor shares other's underlying storage. Returns + /// true iff other.shape() has the same number of elements of the + /// given "shape". + bool CopyFrom(const Tensor& other, + const TensorShape& shape) TF_MUST_USE_RESULT { + if (other.NumElements() != shape.num_elements()) return false; + CopyFromInternal(other, shape); + return true; + } + + /// \brief Slice this tensor along the 1st dimension. + + /// I.e., the returned + /// tensor satisifies returned[i, ...] == this[dim0_start + i, ...]. + /// The returned tensor shares the underlying tensor buffer with this + /// tensor. + /// + /// NOTE: The returned tensor may not satisfies the same alignment + /// requirement as this tensor depending on the shape. The caller + /// must check the returned tensor's alignment before calling certain + /// methods that have alignment requirement (e.g., flat(), tensor()). + /// + /// REQUIRES: dims() >= 1 + /// REQUIRES: 0 <= dim0_start <= dim0_limit <= dim_size(0) + Tensor Slice(int64 dim0_start, int64 dim0_limit) const; + + /// \brief Parse "other' and construct the tensor. + + /// Returns true iff the + /// parsing succeeds. If the parsing fails, the state of "*this" is + /// unchanged. + bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT; + bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT; + + /// \brief Fills in "proto" with "*this" tensor's content. + /// + /// AsProtoField() fills in the repeated field for proto.dtype(), while + /// AsProtoTensorContent() encodes the content in proto.tensor_content() in a + /// compact form. + void AsProtoField(TensorProto* proto) const; + void AsProtoTensorContent(TensorProto* proto) const; + + /// \brief Return the Tensor data as an Eigen::Tensor with the type and + /// sizes of this Tensor. + /// + /// Use these methods when you know the data type and the number of + /// dimensions of the Tensor and you want an Eigen::Tensor + /// automatically sized to the Tensor sizes. The implementation check + /// fails if either type or sizes mismatch. + /// + /// Example: + /// typedef float T; + /// Tensor my_mat(...built with Shape{rows: 3, cols: 5}...); + /// auto mat = my_mat.matrix(); // 2D Eigen::Tensor, 3 x 5. + /// auto mat = my_mat.tensor(); // 2D Eigen::Tensor, 3 x 5. + /// auto vec = my_mat.vec(); // CHECK fails as my_mat is 2D. + /// auto vec = my_mat.tensor(); // CHECK fails as my_mat is 2D. + /// auto mat = my_mat.matrix();// CHECK fails as type mismatch. + template + typename TTypes::Vec vec() { + return tensor(); + } + + template + typename TTypes::Matrix matrix() { + return tensor(); + } + + template + typename TTypes::Tensor tensor(); + + /// \brief Return the Tensor data as an Eigen::Tensor of the data type and a + /// specified shape. + /// + /// These methods allow you to access the data with the dimensions + /// and sizes of your choice. You do not need to know the number of + /// dimensions of the Tensor to call them. However, they CHECK that + /// the type matches and the dimensions requested creates an + /// Eigen::Tensor with the same number of elements as the Tensor. 
+ /// + /// Example: + /// typedef float T; + /// Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...); + /// // 1D Eigen::Tensor, size 60: + /// auto flat = my_ten.flat(); + /// // 2D Eigen::Tensor 12 x 5: + /// auto inner = my_ten.flat_inner_dims(); + /// // 2D Eigen::Tensor 4 x 15: + /// auto outer = my_ten.shaped({4, 15}); + /// // CHECK fails, bad num elements: + /// auto outer = my_ten.shaped({4, 8}); + /// // 3D Eigen::Tensor 6 x 5 x 2: + /// auto weird = my_ten.shaped({6, 5, 2}); + /// // CHECK fails, type mismatch: + /// auto bad = my_ten.flat(); + template + typename TTypes::Flat flat() { + return shaped({NumElements()}); + } + + template + typename TTypes::UnalignedFlat unaligned_flat() { + return unaligned_shaped({NumElements()}); + } + + /// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all + /// Tensor dimensions but the last one into the first dimension of the result. + template + typename TTypes::Matrix flat_inner_dims() { + int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1; + if (last_size == 0) { + DCHECK_EQ(NumElements(), 0); + // Return something empty, avoiding divide by 0 + return shaped({0, 0}); + } else { + return shaped({NumElements() / last_size, last_size}); + } + } + + /// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all + /// Tensor dimensions but the first one into the last dimension of the result. + template + typename TTypes::Matrix flat_outer_dims() { + int64 first_size = dims() > 0 ? dim_size(0) : 1; + if (first_size == 0) { + DCHECK_EQ(NumElements(), 0); + // Return something empty, avoiding divide by 0 + return shaped({0, 0}); + } else { + return shaped({first_size, NumElements() / first_size}); + } + } + + template + typename TTypes::Tensor shaped(gtl::ArraySlice new_sizes); + + template + typename TTypes::UnalignedTensor unaligned_shaped( + gtl::ArraySlice new_sizes); + + /// \brief Return the Tensor data as a Tensor Map of fixed size 1: + /// TensorMap>. + + /// Using scalar() allows the compiler to + /// perform optimizations as the size of the tensor is known at compile time. + template + typename TTypes::Scalar scalar(); + + /// Const versions of all the methods above. + template + typename TTypes::ConstVec vec() const { + return tensor(); + } + + template + typename TTypes::ConstMatrix matrix() const { + return tensor(); + } + + template + typename TTypes::ConstTensor tensor() const; + + template + typename TTypes::ConstFlat flat() const { + return shaped({NumElements()}); + } + + template + typename TTypes::UnalignedConstFlat unaligned_flat() const { + return unaligned_shaped({NumElements()}); + } + + template + typename TTypes::ConstMatrix flat_inner_dims() const { + int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1; + if (last_size == 0) { + DCHECK_EQ(NumElements(), 0); + // Return something empty, avoiding divide by 0 + return shaped({0, 0}); + } else { + return shaped({NumElements() / last_size, last_size}); + } + } + + template + typename TTypes::ConstMatrix flat_outer_dims() const { + int64 first_size = dims() > 0 ? 
dim_size(0) : 1; + if (first_size == 0) { + DCHECK_EQ(NumElements(), 0); + // Return something empty, avoiding divide by 0 + return shaped({0, 0}); + } else { + return shaped({first_size, NumElements() / first_size}); + } + } + + template + typename TTypes::ConstTensor shaped( + gtl::ArraySlice new_sizes) const; + template + typename TTypes::UnalignedConstTensor unaligned_shaped( + gtl::ArraySlice new_sizes) const; + + template + typename TTypes::ConstScalar scalar() const; + + /// Render the first max_entries values in *this into a string. + string SummarizeValue(int64 max_entries) const; + + /// A human-readable summary of the Tensor suitable for debugging. + string DebugString() const; + + /// Fill in the TensorDescription proto with metadata about the + /// Tensor that is useful for monitoring and debugging. + void FillDescription(TensorDescription* description) const; + + /// \brief Returns a StringPiece mapping the current tensor's buffer. + /// + /// The returned StringPiece may point to memory location on devices + /// that the CPU cannot address directly. + /// + /// NOTE: The underlying Tensor buffer is refcounted, so the lifetime + /// of the contents mapped by the StringPiece matches the lifetime of + /// the buffer; callers should arrange to make sure the buffer does + /// not get destroyed while the StringPiece is still used. + /// + /// REQUIRES: DataTypeCanUseMemcpy(dtype()). + StringPiece tensor_data() const; + + private: + DataType type_; + TensorShape shape_; + TensorBuffer* buf_; + + friend class DMAHelper; + friend class TensorCApi; + friend class VariableOp; // For access to set_shape + friend class AutoReloadVariableOp; // For access to set_shape + + // Creates a tensor with the input datatype, shape and buf. + // + // Acquires a ref on buf that belongs to this Tensor. + Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf); + + bool CanUseDMA() const; + + // Only needed by variable op to set the shape of an uninitialized + // Tensor. + // TODO: Remove this when we have a better story for detecting + // uninitialized tensors. + void set_shape(const TensorShape& shape) { shape_ = shape; } + + void CopyFromInternal(const Tensor& other, const TensorShape& shape); + + template + T* base() const; +}; + +// Implementation details + +// Interface to access the raw ref-counted data buffer. +class TensorBuffer : public core::RefCounted { + public: + ~TensorBuffer() override {} + + // data() points to a memory region of size() bytes. + virtual void* data() const = 0; + virtual size_t size() const = 0; + + // If this TensorBuffer is sub-buffer of another TensorBuffer, + // returns that TensorBuffer. Otherwise, returns this. + virtual TensorBuffer* root_buffer() = 0; + + // Fill metadata about the allocation into the proto. + virtual void FillAllocationDescription( + AllocationDescription* proto) const = 0; + + template + T* base() const { + return reinterpret_cast(data()); + } +}; + +inline void CheckEigenAlignment(const void* ptr) { +#if EIGEN_ALIGN == 1 + CHECK_EQ(reinterpret_cast(ptr) % EIGEN_ALIGN_BYTES, 0); +#endif +} + +template +T* Tensor::base() const { + return buf_ == nullptr ? 
nullptr : buf_->base(); +} + +template +typename TTypes::Tensor Tensor::tensor() { + CHECK_EQ(dtype(), DataTypeToEnum::v()); + CheckEigenAlignment(base()); + return typename TTypes::Tensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::ConstTensor Tensor::tensor() const { + CheckEigenAlignment(base()); + CHECK_EQ(dtype(), DataTypeToEnum::v()); + return typename TTypes::ConstTensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::Tensor Tensor::shaped( + gtl::ArraySlice new_sizes) { + CheckEigenAlignment(base()); + CHECK_EQ(dtype(), DataTypeToEnum::v()); + CHECK_EQ(NDIMS, new_sizes.size()); + int64 new_num_elements = 1; + Eigen::array dims; + for (int d = 0; d < NDIMS; d++) { + new_num_elements *= new_sizes[d]; + dims[d] = new_sizes[d]; + } + CHECK_EQ(new_num_elements, NumElements()); + return typename TTypes::Tensor(base(), dims); +} + +template +typename TTypes::UnalignedTensor Tensor::unaligned_shaped( + gtl::ArraySlice new_sizes) { + CHECK_EQ(dtype(), DataTypeToEnum::v()); + CHECK_EQ(NDIMS, new_sizes.size()); + int64 new_num_elements = 1; + Eigen::array dims; + for (int d = 0; d < NDIMS; d++) { + new_num_elements *= new_sizes[d]; + dims[d] = new_sizes[d]; + } + CHECK_EQ(new_num_elements, NumElements()); + return typename TTypes::UnalignedTensor(base(), dims); +} + +template +typename TTypes::ConstTensor Tensor::shaped( + gtl::ArraySlice new_sizes) const { + CheckEigenAlignment(base()); + CHECK_EQ(dtype(), DataTypeToEnum::v()); + CHECK_EQ(NDIMS, new_sizes.size()); + int64 new_num_elements = 1; + Eigen::array dims; + for (int d = 0; d < NDIMS; d++) { + new_num_elements *= new_sizes[d]; + dims[d] = new_sizes[d]; + } + CHECK_EQ(new_num_elements, NumElements()); + return typename TTypes::ConstTensor(base(), dims); +} + +template +typename TTypes::UnalignedConstTensor Tensor::unaligned_shaped( + gtl::ArraySlice new_sizes) const { + CHECK_EQ(dtype(), DataTypeToEnum::v()); + CHECK_EQ(NDIMS, new_sizes.size()); + int64 new_num_elements = 1; + Eigen::array dims; + for (int d = 0; d < NDIMS; d++) { + new_num_elements *= new_sizes[d]; + dims[d] = new_sizes[d]; + } + CHECK_EQ(new_num_elements, NumElements()); + return typename TTypes::UnalignedConstTensor(base(), dims); +} + +template +typename TTypes::Scalar Tensor::scalar() { + CheckEigenAlignment(base()); + CHECK_EQ(1, NumElements()) << "Must have a one element tensor"; + return typename TTypes::Scalar(base()); +} + +template +typename TTypes::ConstScalar Tensor::scalar() const { + CheckEigenAlignment(base()); + CHECK_EQ(1, NumElements()) << "Must have a one element tensor"; + return typename TTypes::ConstScalar(base()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_PUBLIC_TENSOR_H_ diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h new file mode 100644 index 0000000000..fe1846319e --- /dev/null +++ b/tensorflow/core/public/tensor_c_api.h @@ -0,0 +1,243 @@ +// TODO(jeff,sanjay): Rename to tensorflow/public/c_api.h +#ifndef TENSORFLOW_PUBLIC_TENSOR_C_API_H_ +#define TENSORFLOW_PUBLIC_TENSOR_C_API_H_ + +#include + +// -------------------------------------------------------------------------- +// C API for TensorFlow. +// +// The API leans towards simplicity and uniformity instead of convenience +// since most usage will be by language specific wrappers. +// +// Conventions: +// * We use the prefix TF_ for everything in the API. +// * Objects are always passed around as pointers to opaque structs +// and these structs are allocated/deallocated via the API. 
+// * TF_Status holds error information. It is an object type +// and threfore is passed around as a pointer to an opaque +// struct as mentioned above. +// * Every call that has a TF_Status* argument clears it on success +// and fills it with error info on failure. +// +// Questions left to address: +// * Might need to add stride info to TF_Tensor? +// * Might at some point need a way for callers to provide their own Env. +// * Should we remove the TF_Status arg from TF_AddProto calls and only +// report errors later (e.g., on Run call). +// * Should dimensions be unsigned instead of signed? +// * Maybe add TF_TensorShape that encapsulates dimension info. +// +// Design decisions made: +// * Backing store for tensor memory has an associated deallocation +// function. This deallocation function will point to client code +// for tensors populated by the client. So the client can do things +// like shadowing a numpy array. +// * We do not provide TF_OK since it is not strictly necessary and we +// are not optimizing for convenience. +// * We make assumption that one session has one graph. This should be +// fine since we have the ability to run sub-graphs. +// * We are not providing TF_AddNode/TF_AddNodes to better support +// languages/platforms where proto is not available. This is because +// we can just point authors of bindings at the .proto file and the +// proto serialization spec and they can do the right thing for +// their language. +// * We could allow NULL for some arguments (e.g., NULL options arg). +// However since convenience is not a primary goal, we don't do this. +// * Devices are not in this API. Instead, they are created/used internally +// and the API just provides high level controls over the number of +// devices of each type. + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. +// The enum values here are identical to corresponding values in types.proto. +typedef enum { + TF_FLOAT = 1, + TF_DOUBLE = 2, + TF_INT32 = 3, // Int32 tensors are always in 'host' memory. + TF_UINT8 = 4, + TF_INT16 = 5, + TF_INT8 = 6, + TF_STRING = 7, + TF_COMPLEX = 8, // Single-precision complex + TF_INT64 = 9, + TF_BOOL = 10, + TF_QINT8 = 11, // Quantized int8 + TF_QUINT8 = 12, // Quantized uint8 + TF_QINT32 = 13, // Quantized int32 + TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops. +} TF_DataType; + +// -------------------------------------------------------------------------- +// TF_Code holds an error code. The enum values here are identical to +// corresponding values in error_codes.proto. +typedef enum { + TF_OK = 0, + TF_CANCELLED = 1, + TF_UNKNOWN = 2, + TF_INVALID_ARGUMENT = 3, + TF_DEADLINE_EXCEEDED = 4, + TF_NOT_FOUND = 5, + TF_ALREADY_EXISTS = 6, + TF_PERMISSION_DENIED = 7, + TF_UNAUTHENTICATED = 16, + TF_RESOURCE_EXHAUSTED = 8, + TF_FAILED_PRECONDITION = 9, + TF_ABORTED = 10, + TF_OUT_OF_RANGE = 11, + TF_UNIMPLEMENTED = 12, + TF_INTERNAL = 13, + TF_UNAVAILABLE = 14, + TF_DATA_LOSS = 15, +} TF_Code; + +// -------------------------------------------------------------------------- +// TF_Status holds error information. It either has an OK code, or +// else an error code with an associated error message. +typedef struct TF_Status TF_Status; + +// Return a new status object. +extern TF_Status* TF_NewStatus(); + +// Delete a previously created status object. 
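A sketch of the status lifecycle using only calls declared in this header:

```c++
#include <stdio.h>

void StatusExample() {
  TF_Status* s = TF_NewStatus();
  TF_SetStatus(s, TF_INVALID_ARGUMENT, "something was malformed");
  if (TF_GetCode(s) != TF_OK) {
    fprintf(stderr, "error: %s\n", TF_Message(s));
  }
  TF_SetStatus(s, TF_OK, "");  // reset to OK for reuse
  TF_DeleteStatus(s);
}
```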
+extern void TF_DeleteStatus(TF_Status*); + +// Record in *s. Any previous information is lost. +// A common use is to clear a status: TF_SetStatus(s, TF_OK, ""); +extern void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg); + +// Return the code record in *s. +extern TF_Code TF_GetCode(const TF_Status* s); + +// Return a pointer to the error message in *s. The return value +// points to memory that is only usable until the next mutation to *s. +// Always returns an empty string if TF_GetCode(s) is TF_OK. +extern const char* TF_Message(const TF_Status* s); + +// -------------------------------------------------------------------------- +// TF_Tensor holds a multi-dimensional array of elements of a single data type. +// For all types other than TF_STRING, the data buffer stores elements +// in row major order. E.g. if data is treated as a vector of TF_DataType: +// +// element 0: index (0, ..., 0) +// element 1: index (0, ..., 1) +// ... +// +// TODO(jeff,sanjay): Define format for TF_STRING tensors. Perhaps: +// start_offset: array[uint64] +// data: byte[...] +// +// String length is encoded (varint?) starting at data[start_offset[i]] +// String contents follow immediately after string length. + +typedef struct TF_Tensor TF_Tensor; + +// Return a new tensor that holds the bytes data[0,len-1]. +// +// The data will be deallocated by a subsequent call to TF_DeleteTensor via: +// (*deallocator_fn)(data, len, deallocator_arg) +// Clients can provide a custom deallocator function so they can pass in +// memory managed by something like numpy. +extern TF_Tensor* TF_NewTensor(TF_DataType, long long* dims, int num_dims, + void* data, size_t len, + void (*deallocator)(void* data, size_t len, + void* arg), + void* deallocator_arg); + +// Destroy a tensor. +extern void TF_DeleteTensor(TF_Tensor*); + +// Return the type of a tensor element. +extern TF_DataType TF_TensorType(const TF_Tensor*); + +// Return the number of dimensions that the tensor has. +extern int TF_NumDims(const TF_Tensor*); + +// Return the length of the tensor in the "dim_index" dimension. +// REQUIRES: 0 <= dim_index < TF_NumDims(tensor) +extern long long TF_Dim(const TF_Tensor* tensor, int dim_index); + +// Return the size of the underlying data in bytes. +extern size_t TF_TensorByteSize(const TF_Tensor*); + +// Return a pointer to the underlying data buffer. +extern void* TF_TensorData(const TF_Tensor*); + +// -------------------------------------------------------------------------- +// TF_SessionOptions holds options that can be passed during session creation. +typedef struct TF_SessionOptions TF_SessionOptions; + +// Return a new options object. +extern TF_SessionOptions* TF_NewSessionOptions(); + +// Set the target in TF_SessionOptions.options. +// target can be empty, a single entry, or a comma separated list of entries. +// Each entry is in one of the following formats : +// "local" +// ip:port +// host:port +extern void TF_SetTarget(TF_SessionOptions* options, const char* target); + +// Set the config in TF_SessionOptions.options. +// config should be a serialized brain.ConfigProto proto. +// If config was not parsed successfully as a ConfigProto, record the +// error information in *status. +extern void TF_SetConfig(TF_SessionOptions* options, const char* config, + size_t config_len, TF_Status* status); + +// Destroy an options object. 
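// A typical create/configure/destroy sequence looks roughly like this
// (an illustrative sketch only; TF_NewSession is declared further below):
//   TF_Status* s = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_SetTarget(opts, "local");
//   TF_Session* session = TF_NewSession(opts, s);
//   TF_DeleteSessionOptions(opts);
//   TF_DeleteStatus(s);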
+extern void TF_DeleteSessionOptions(TF_SessionOptions*); + +// TODO(jeff,sanjay): +// - export functions to set Config fields + +// -------------------------------------------------------------------------- +// TF_Session manages a single graph and execution. +typedef struct TF_Session TF_Session; + +// Return a new execution session, or NULL on error. +extern TF_Session* TF_NewSession(const TF_SessionOptions*, TF_Status* status); + +// Close a session. +extern void TF_CloseSession(TF_Session*, TF_Status* status); + +// Destroy a session. Even if error information is recorded in *status, +// this call discards all resources associated with the session. +extern void TF_DeleteSession(TF_Session*, TF_Status* status); + +// Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and +// add the nodes in that GraphDef to the graph for the session. +extern void TF_ExtendGraph(TF_Session*, const void* proto, size_t proto_len, + TF_Status*); + +// Run the graph associated with the session starting with the +// supplied inputs (inputs[0,ninputs-1]). Regardless of success or +// failure, inputs[] become the property of the implementation (the +// implementation will eventually call TF_DeleteTensor on each input). +// +// On success, the tensors corresponding to output_names[0,noutputs-1] +// are placed in outputs[]. and these outputs[] become the property +// of the caller (the caller must eventually call TF_DeleteTensor on +// them). +// +// On failure, outputs[] contains nulls. +extern void TF_Run(TF_Session*, + // Input tensors + const char** input_names, TF_Tensor** inputs, int ninputs, + // Output tensors + const char** output_tensor_names, TF_Tensor** outputs, + int noutputs, + // Target nodes + const char** target_node_names, int ntargets, + // Output status + TF_Status*); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_PUBLIC_TENSOR_C_API_H_ diff --git a/tensorflow/core/public/tensor_shape.h b/tensorflow/core/public/tensor_shape.h new file mode 100644 index 0000000000..a889b8b17d --- /dev/null +++ b/tensorflow/core/public/tensor_shape.h @@ -0,0 +1,239 @@ +#ifndef TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_ +#define TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_ + +#include + +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +class TensorShapeIter; // Declared below + +/// Manages the dimensions of a Tensor and their sizes. +class TensorShape { + public: + /// \brief Construct a TensorShape from the provided sizes.. + /// REQUIRES: dim_sizes[i] >= 0 + explicit TensorShape(gtl::ArraySlice dim_sizes); + TensorShape(std::initializer_list dim_sizes) + : TensorShape(gtl::ArraySlice(dim_sizes)) {} + + /// REQUIRES: IsValid(proto) + explicit TensorShape(const TensorShapeProto& proto); + + /// Create a tensor shape with no dimensions and one element, which you can + /// then call AddDim() on. + TensorShape(); + + /// Returns true iff "proto" is a valid tensor shape. + static bool IsValid(const TensorShapeProto& proto); + + /// Clear a tensor shape + void Clear(); + + /// \brief Add a dimension to the end ("inner-most"). + /// REQUIRES: size >= 0 + void AddDim(int64 size); + + /// Appends all the dimensions from shape. 
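  /// For example (illustrative):
  ///   TensorShape s({2, 3});                 // [2, 3]
  ///   s.AddDim(5);                           // [2, 3, 5]
  ///   s.AppendShape(TensorShape({7, 11}));   // [2, 3, 5, 7, 11]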
+ void AppendShape(const TensorShape& shape); + + /// \brief Insert a dimension somewhere in the TensorShape. + /// REQUIRES: "0 <= d <= dims()" + /// REQUIRES: size >= 0 + void InsertDim(int d, int64 size); + + /// \brief Modifies the size of the dimension 'd' to be 'size' + /// REQUIRES: "0 <= d < dims()" + /// REQUIRES: size >= 0 + void set_dim(int d, int64 size); + + /// \brief Removes dimension 'd' from the TensorShape. + /// REQUIRES: "0 <= d < dims()" + void RemoveDim(int d); + + /// Return the number of dimensions in the tensor. + int dims() const { return dim_sizes_.size(); } + + /// \brief Returns the number of elements in dimension "d". + /// REQUIRES: "0 <= d < dims()" + // TODO(mdevin): Rename to dimension() to match Eigen::Tensor::dimension()? + int64 dim_size(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return dim_sizes_[d]; + } + + /// Returns sizes of all dimensions. + gtl::ArraySlice dim_sizes() const { return dim_sizes_; } + + /// \brief Returns the number of elements in the tensor. + /// + /// We use int64 and + /// not size_t to be compatible with Eigen::Tensor which uses ptr_fi + int64 num_elements() const { return num_elements_; } + + /// Returns true if *this and b have the same sizes. Ignores dimension names. + bool IsSameSize(const TensorShape& b) const; + bool operator==(const TensorShape& b) const { return IsSameSize(b); } + + /// Fill *proto from *this. + void AsProto(TensorShapeProto* proto) const; + + /// Fill *dsizes from *this. + template + Eigen::DSizes AsEigenDSizes() const; + + /// Same as AsEigenDSizes() but allows for NDIMS > dims() -- in which case we + /// pad the rest of the sizes with 1. + template + Eigen::DSizes AsEigenDSizesWithPadding() const; + + /// For iterating through the dimensions. + TensorShapeIter begin() const; + TensorShapeIter end() const; + + /// For error messages. + string DebugString() const; + // TODO(vrv): Remove this, this is the same as DebugString(). + string ShortDebugString() const; + + private: + /// Recalculates the dimensions of this tensor after they are modified. + void recompute_dims(); + + // TODO(josh11b): Maybe use something from the Eigen Tensor library + /// for the sizes. + gtl::InlinedVector dim_sizes_; + + /// total number of elements (avoids recomputing it each time). + int64 num_elements_; +}; + +struct TensorShapeDim { + explicit TensorShapeDim(int64 s) : size(s) {} + int size; +}; + +class TensorShapeIter { + public: + TensorShapeIter(const TensorShape* shape, int d) : shape_(shape), d_(d) {} + bool operator==(const TensorShapeIter& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ == rhs.d_; + } + bool operator!=(const TensorShapeIter& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ != rhs.d_; + } + void operator++() { ++d_; } + TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); } + + private: + const TensorShape* shape_; + int d_; +}; + +// In some places, allow shape (1,) to be treated as a scalar and shape () to be +// treated as a vector. This flag is for temporary backwards compatibility +// only, and will be changed to strict within Google around November 15, 2015. +#if defined(PLATFORM_GOOGLE) +// TODO(irving): Become strict on November 15, 2015. +static const bool kAllowLegacyScalars = true; +#else +// For open source (outside Google), we are strict. +static const bool kAllowLegacyScalars = false; +#endif + +/// \brief Static helper routines for TensorShape. Includes a few common +/// predicates on a tensor shape. 
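/// For example (illustrative):
///   TensorShapeUtils::IsScalar(TensorShape());        // true
///   TensorShapeUtils::IsMatrix(TensorShape({2, 3}));  // true
///   int32 dims[] = {2, 3, 5};
///   TensorShapeUtils::MakeShape(dims, 3);             // shape [2, 3, 5]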
+class TensorShapeUtils { + public: + static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; } + + static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; } + + // Allow either scalars or (if allowing legacy scalars) shape (1,). + static bool IsLegacyScalar(const TensorShape& shape) { + return shape.dims() == 0 || + (kAllowLegacyScalars && shape.dims() == 1 && shape.dim_size(0) == 1); + } + + // Allow rank 1 or (if allowing legacy scalars) rank 0. + static bool IsLegacyVector(const TensorShape& shape) { + return shape.dims() == 1 || (kAllowLegacyScalars && shape.dims() == 0); + } + + static bool IsVectorOrHigher(const TensorShape& shape) { + return shape.dims() >= 1; + } + + static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; } + + static bool IsMatrixOrHigher(const TensorShape& shape) { + return shape.dims() >= 2; + } + + /// \brief Returns a TensorShape whose dimensions are dims[0], dims[1], ..., + /// dims[n-1]. + template + static TensorShape MakeShape(const T* dims, int n) { + TensorShape shape; + for (int i = 0; i < n; ++i) shape.AddDim(dims[i]); + return shape; + } + + static string ShapeListString(const gtl::ArraySlice& shapes) { + string result = "["; + bool first = true; + for (const TensorShape& shape : shapes) { + strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString()); + first = false; + } + strings::StrAppend(&result, "]"); + return result; + } + + static bool StartsWith(const TensorShape& shape0, const TensorShape& shape1); +}; + +// TODO(josh11b): Add TensorStrides once we support strides +// struct TensorStrides { +// gtl::InlinedVector strides_; +// }; + +// ---------------------------------------------------------------------------- +// Template method implementation details below +// ---------------------------------------------------------------------------- + +template +Eigen::DSizes TensorShape::AsEigenDSizes() const { + CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS + << " for a tensor of " << dims() << " dimensions"; + return AsEigenDSizesWithPadding(); +} + +template +Eigen::DSizes TensorShape::AsEigenDSizesWithPadding() + const { + CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS + << " for a tensor of " << dims() << " dimensions"; + Eigen::DSizes dsizes; + for (int d = 0; d < dims(); d++) { + dsizes[d] = dim_size(d); + } + for (int d = dims(); d < NDIMS; d++) { + dsizes[d] = 1; + } + return dsizes; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_ diff --git a/tensorflow/core/public/tensorflow_server.h b/tensorflow/core/public/tensorflow_server.h new file mode 100644 index 0000000000..0dac414555 --- /dev/null +++ b/tensorflow/core/public/tensorflow_server.h @@ -0,0 +1,19 @@ +#ifndef TENSORFLOW_PUBLIC_TENSORFLOW_SERVER_H_ +#define TENSORFLOW_PUBLIC_TENSORFLOW_SERVER_H_ + +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Initialize the TensorFlow service for this address space. +// This is a blocking call that never returns. +// See BUILD file for details on linkage guidelines. +::tensorflow::Status InitTensorFlow(); + +// Like InitTensorFlow() but returns after the Tensorflow +// services have been launched. 
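// For example, a caller might check the returned Status like this
// (an illustrative sketch only):
//   ::tensorflow::Status s = ::tensorflow::LaunchTensorFlow();
//   if (!s.ok()) LOG(FATAL) << "Failed to launch TensorFlow: " << s;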
+::tensorflow::Status LaunchTensorFlow(); + +} // namespace tensorflow + +#endif // TENSORFLOW_PUBLIC_TENSORFLOW_SERVER_H_ diff --git a/tensorflow/core/user_ops/fact.cc b/tensorflow/core/user_ops/fact.cc new file mode 100644 index 0000000000..7b6932244d --- /dev/null +++ b/tensorflow/core/user_ops/fact.cc @@ -0,0 +1,29 @@ +// An example Op. + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +using namespace tensorflow; + +REGISTER_OP("Fact") + .Output("fact: string") + .Doc(R"doc( +Output a fact about factorials. +)doc"); + +class FactOp : public OpKernel { + public: + explicit FactOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Output a scalar string. + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape(), &output_tensor)); + auto output = output_tensor->template scalar(); + + output() = "0! == 1"; + } +}; + +REGISTER_KERNEL_BUILDER(Name("Fact").Device(DEVICE_CPU), FactOp); diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc new file mode 100644 index 0000000000..4e70b78751 --- /dev/null +++ b/tensorflow/core/util/bcast.cc @@ -0,0 +1,120 @@ +#include "tensorflow/core/util/bcast.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +/* static */ +void BCast::Reverse(Vec* shape) { std::reverse(shape->begin(), shape->end()); } + +BCast::BCast(const Vec& sx, const Vec& sy) { + // Reverse the shape of x and y for convenience. + // After the reverse, 0-th is the inner-most dimension. + Vec x = sx; + Reverse(&x); + Vec y = sy; + Reverse(&y); + + // 1-extend and align x and y so that they are the same size. + if (x.size() > y.size()) { + y.resize(x.size(), 1); + } else { + x.resize(y.size(), 1); + } + + // Going through each dimension starting from the inner-most + // dimension, compares dimension of x and y. They are compatible if + // they are equal or either is 1. + enum State { + UNKNOWN, + SAME, + X_ONE, + Y_ONE, + }; + State prev = UNKNOWN; + const int64 n = x.size(); + for (int i = 0; i < n; ++i) { + // Output shape. + State curr = UNKNOWN; + const int64 x_i = x[i]; // i-th dimension of x. + CHECK_GE(x_i, 0); + const int64 y_i = y[i]; // i-th dimension of y. + CHECK_GE(y_i, 0); + int64 o_i; // i-th dimension of the output. + int64 bx_i; // i-th broadcast for x. + int64 by_i; // i-th broadcast for y. + // Invariant: + // o_i = x_i * bx_i = y_i * by_i + if (x_i == y_i) { + // No broadcast. + o_i = x_i; + bx_i = 1; + by_i = 1; + curr = SAME; + } else if (x_i == 1) { + // x broadcast to y on this dimension. + o_i = y_i; + bx_i = y_i; + by_i = 1; + grad_x_reduce_idx_.push_back(n - 1 - i); + curr = X_ONE; + } else if (y_i == 1) { + // y broadcast to x on this dimension. + o_i = x_i; + bx_i = 1; + by_i = x_i; + grad_y_reduce_idx_.push_back(n - 1 - i); + curr = Y_ONE; + } else { + valid_ = false; + return; + } + output_.push_back(o_i); + // Reshape/broadcast. + // Invariant: + // result[i] == x_reshape[i] * x_bcast[i] == y_reshape_[i] * y_bcast_[i] + if (curr == SAME && x_i == 1) { + // Both side are 1s. + grad_x_reduce_idx_.push_back(n - 1 - i); + grad_y_reduce_idx_.push_back(n - 1 - i); + continue; + } else if (prev == curr) { + // It is a run of the same cases (no broadcast, x broadcast to + // y, y broadcast to x). We can reshape the input so that fewer + // dimensions are involved in the intermediate computation. 
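      // For example (cf. bcast_test.cc, Basic_Tensor_Matrix), broadcasting
      // [11, 7, 5, 3, 2] against [3, 2]: after reversing and 1-extending,
      // x = [2, 3, 5, 7, 11] and y = [2, 3, 1, 1, 1]. The two innermost
      // SAME dimensions collapse into one group of size 6 and the three
      // remaining Y_ONE dimensions collapse into a group of size 385, so
      // (after the final reverse) x_reshape = [385, 6], x_bcast = [1, 1],
      // y_reshape = [1, 6], y_bcast = [385, 1].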
+ result_.back() *= o_i; + x_reshape_.back() *= x_i; + x_bcast_.back() *= bx_i; + y_reshape_.back() *= y_i; + y_bcast_.back() *= by_i; + } else { + result_.push_back(o_i); + x_reshape_.push_back(x_i); + x_bcast_.push_back(bx_i); + y_reshape_.push_back(y_i); + y_bcast_.push_back(by_i); + } + prev = curr; + } + + if (result_.empty()) { + // Can happen when both x and y are effectively scalar. + result_.push_back(1); + x_reshape_.push_back(1); + x_bcast_.push_back(1); + y_reshape_.push_back(1); + y_bcast_.push_back(1); + } + + // Reverse all vectors since x and y were reversed at very + // beginning. + Reverse(&x_reshape_); + Reverse(&x_bcast_); + Reverse(&y_reshape_); + Reverse(&y_bcast_); + Reverse(&result_); + Reverse(&output_); + Reverse(&grad_x_reduce_idx_); + Reverse(&grad_y_reduce_idx_); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h new file mode 100644 index 0000000000..9f0233e415 --- /dev/null +++ b/tensorflow/core/util/bcast.h @@ -0,0 +1,99 @@ +#ifndef TENSORFLOW_UTIL_BCAST_H_ +#define TENSORFLOW_UTIL_BCAST_H_ + +#include +#include + +#include "tensorflow/core/platform/port.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +// BCast is a helper for broadcasting binary tensor operation. +// TensorFlow's broadcasting rule follows that of numpy (See +// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). +// +// The rule has the following properties: +// +// 1. suffix matching: the rule starts with the right-most +// dimension, and works towards the left-most dimension. Since +// TensorFlow is row-major, the right-most dimension (the last +// element in the shape of a tensor) is the inner-most, a.k.a. +// the fastest changing, dimension. +// +// 2. Two dimensions are compatible for broadcasting if both are the +// same or either is 1. +// +// BCast takes the shape of two tensors and computes a few vectors of +// int32 that are useful for the caller to reshape the tensors, apply +// the right broadcasts to them, compute the broadcasted operation, +// and possibly the gradients. In a nutshell, the caller is expected +// to compute the broadcasted operation as following: +// +// BCast b(x.shape(), y.shape()); +// output = x.reshape(b.x_reshape()).broadcast(b.x_bcast()) +// _op_ +// y.reshape(b.y_reshape()).broadcast(b.y_bcast()) +// +// For the gradient computation, +// grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx) +// .reshape(x.shape()) +// grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx) +// .reshape(y.shape()) +// backprop_x and backprop_y are functionals of the binary function "op", +// e.g., +// for +, backprop_x(x, y) = backprop_y(x, y) = 1; +// for *, backprop_x(x, y) = y, backprop_y(x, y) = x; +// for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2; +// +// The multiplication in the grad * backprop_x itself is also +// broadcasting following the same rule. +// +// TODO(zhifengc): Adds support for n-ary (n >= 2). +class BCast { + public: + // A vector of int32 representing the shape of tensor. The 0-th + // element is the outer-most dimension and the last element is the + // inner-most dimension. Note that we do not use TensorShape since + // it's more convenient to manipulate Vec directly for this module. + typedef std::vector Vec; + + BCast(const Vec& x, const Vec& y); + ~BCast() {} + + // Returns true iff two operands are compatible according to the + // broadcasting rule. 
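  // For example, BCast({5, 3, 2}, {3}).IsValid() is false (the trailing
  // dimensions 2 and 3 are incompatible), while
  // BCast({5, 3, 2}, {3, 1}).IsValid() is true.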
+ bool IsValid() const { return valid_; } + + // If and only if IsValid(), the following fields can be used in + // implementing a broadcasted binary tensor operation according to + // the broadcasting rule. + const Vec& x_reshape() const { return x_reshape_; } + const Vec& x_bcast() const { return x_bcast_; } + const Vec& y_reshape() const { return y_reshape_; } + const Vec& y_bcast() const { return y_bcast_; } + const Vec& result_shape() const { return result_; } + const Vec& output_shape() const { return output_; } + const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; } + const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; } + + private: + bool valid_ = true; + Vec x_reshape_; + Vec x_bcast_; + Vec y_reshape_; + Vec y_bcast_; + Vec result_; + Vec output_; + Vec grad_x_reduce_idx_; + Vec grad_y_reduce_idx_; + + static void Reverse(Vec* shape); + static bool HasZero(const Vec& shape); + + TF_DISALLOW_COPY_AND_ASSIGN(BCast); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_BCAST_H_ diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc new file mode 100644 index 0000000000..02d18586d6 --- /dev/null +++ b/tensorflow/core/util/bcast_test.cc @@ -0,0 +1,226 @@ +#include "tensorflow/core/util/bcast.h" + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include + +namespace tensorflow { +namespace { + +string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y) { + tensorflow::BCast b(x, y); + if (!b.IsValid()) { + return "invalid"; + } + string ret; + strings::StrAppend(&ret, "[", str_util::Join(b.x_reshape(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.x_bcast(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.y_reshape(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.y_bcast(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.result_shape(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.output_shape(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.grad_x_reduce_idx(), ","), + "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.grad_y_reduce_idx(), ","), + "]"); + return ret; +} + +TEST(BCastTest, Invalid) { + EXPECT_EQ("invalid", BCast({5, 3, 2}, {3})); + EXPECT_EQ("invalid", BCast({5, 3, 2}, {2, 2})); + EXPECT_EQ("invalid", BCast({5, 3, 2}, {10, 1, 1})); + EXPECT_EQ("invalid", BCast({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1})); +} + +TEST(BCastTest, Basic_SameShape) { + // Effectively no broadcast needed. + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}), + "[2310][1][2310][1]" + "[2310]" + "[11,7,5,3,2]" + "[][]"); +} + +TEST(BCastTest, Basic_Scalar_Scalar) { + // Effectively it's a scalar and a scalar. + // [1, 1] [1] + EXPECT_EQ(BCast({1, 1}, {1}), + "[1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1]"); + + // [1] [1, 1] + EXPECT_EQ(BCast({1}, {1, 1}), + "[1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1]"); +} + +TEST(BCastTest, Basic_Tensor_Scalar) { + // Effectively it's a tensor and a scalar. + // [11, 7, 5, 3, 2] [1] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {1}), + "[2310][1][1][2310]" + "[2310]" + "[11,7,5,3,2]" + "[][0,1,2,3,4]"); + + // [1] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2}), + "[1][2310][2310][1]" + "[2310]" + "[11,7,5,3,2]" + "[0,1,2,3,4][]"); +} + +TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) { + // Effectively it's a tensor and a scalar. 
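  // Reminder: the bracket groups produced by the BCast() helper above are, in
  // order, x_reshape, x_bcast, y_reshape, y_bcast, result_shape, output_shape,
  // grad_x_reduce_idx and grad_y_reduce_idx.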
+ // [11, 7, 5, 3, 2, 1] [1] + EXPECT_EQ(BCast({11, 7, 5, 3, 2, 1}, {1}), + "[2310][1][1][2310]" + "[2310]" + "[11,7,5,3,2,1]" + "[5][0,1,2,3,4,5]"); + + // [1] [11, 7, 5, 3, 2, 1] + EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2, 1}), + "[1][2310][2310][1]" + "[2310]" + "[11,7,5,3,2,1]" + "[0,1,2,3,4,5][5]"); + + // Effectively it's a tensor and a scalar. + // [11, 7, 5, 1, 1, 3, 2, 1] [1] + EXPECT_EQ(BCast({11, 7, 5, 1, 1, 3, 2, 1, 1}, {1}), + "[2310][1][1][2310]" + "[2310]" + "[11,7,5,1,1,3,2,1,1]" + "[3,4,7,8][0,1,2,3,4,5,6,7,8]"); + + // [1] [11, 7, 5, 1, 1, 3, 2, 1] + EXPECT_EQ(BCast({1}, {11, 7, 5, 1, 1, 3, 2, 1, 1}), + "[1][2310][2310][1]" + "[2310]" + "[11,7,5,1,1,3,2,1,1]" + "[0,1,2,3,4,5,6,7,8][3,4,7,8]"); +} + +TEST(BCastTest, Basic_Tensor_Vector) { + // [11, 7, 5, 3, 2] [2] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {2}), + "[1155,2][1,1][1,2][1155,1]" + "[1155,2]" + "[11,7,5,3,2]" + "[][0,1,2,3]"); + + // [2] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({2}, {11, 7, 5, 3, 2}), + "[1,2][1155,1][1155,2][1,1]" + "[1155,2]" + "[11,7,5,3,2]" + "[0,1,2,3][]"); +} + +TEST(BCastTest, Basic_Tensor_Matrix) { + // [11, 7, 5, 3, 2] [3, 2] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 2}), + "[385,6][1,1][1,6][385,1]" + "[385,6]" + "[11,7,5,3,2]" + "[][0,1,2]"); + // [3, 2] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({3, 2}, {11, 7, 5, 3, 2}), + "[1,6][385,1][385,6][1,1]" + "[385,6]" + "[11,7,5,3,2]" + "[0,1,2][]"); +} + +TEST(BCastTest, Basic_Tensor_Matrix_Column) { + // [11, 7, 5, 3, 2] [3, 1] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 1}), + "[385,3,2][1,1,1][1,3,1][385,1,2]" + "[385,3,2]" + "[11,7,5,3,2]" + "[][0,1,2,4]"); + + // [3, 1] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({3, 1}, {11, 7, 5, 3, 2}), + "[1,3,1][385,1,2][385,3,2][1,1,1]" + "[385,3,2]" + "[11,7,5,3,2]" + "[0,1,2,4][]"); +} + +TEST(BCastTest, Basic_Tensor_Matrix_As_Tensor) { + // [11, 7, 5, 3, 2] [7, 5, 1, 1] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {7, 5, 1, 1}), + "[11,35,6][1,1,1][1,35,1][11,1,6]" + "[11,35,6]" + "[11,7,5,3,2]" + "[][0,3,4]"); + + // [7, 5, 1, 1] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({7, 5, 1, 1}, {11, 7, 5, 3, 2}), + "[1,35,1][11,1,6][11,35,6][1,1,1]" + "[11,35,6]" + "[11,7,5,3,2]" + "[0,3,4][]"); +} + +TEST(BCastTest, Complex_BCast_To_Each_Other) { + // Rare cases. x and y broadcast to each other. x and y are of + // different ranks. 
+ // Can be verified in numpy as: + // import numpy as np + // x = np.arange(0,110).reshape([11,1,5,1,2]) + // y = np.arange(0,21).reshape([7,1,3,1]) + // np.shape(x + y) + // Out[.]: (11, 7, 5, 3, 2) + EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}), + "[11,1,5,1,2][1,7,1,3,1][1,7,1,3,1][11,1,5,1,2]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[1,3][0,2,4]"); +} + +TEST(BCastTest, TestZeroDimensionShape) { + EXPECT_EQ(BCast({2, 0, 5}, {5}), + "[0,5][1,1][1,5][0,1]" + "[0,5]" + "[2,0,5]" + "[][0,1]"); + EXPECT_EQ(BCast({5}, {2, 0, 5}), + "[1,5][0,1][0,5][1,1]" + "[0,5]" + "[2,0,5]" + "[0,1][]"); + + EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {5}), + "[0,5][1,1][1,5][0,1]" + "[0,5]" + "[2,0,3,0,5]" + "[][0,1,2,3]"); + EXPECT_EQ(BCast({5}, {2, 0, 3, 0, 5}), + "[1,5][0,1][0,5][1,1]" + "[0,5]" + "[2,0,3,0,5]" + "[0,1,2,3][]"); + + EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {3, 1, 5}), + "[0,3,0,5][1,1,1,1][1,3,1,5][0,1,0,1]" + "[0,3,0,5]" + "[2,0,3,0,5]" + "[][0,1,3]"); + EXPECT_EQ(BCast({3, 1, 5}, {2, 0, 3, 0, 5}), + "[1,3,1,5][0,1,0,1][0,3,0,5][1,1,1,1]" + "[0,3,0,5]" + "[2,0,3,0,5]" + "[0,1,3][]"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc new file mode 100644 index 0000000000..b8c6a77dd0 --- /dev/null +++ b/tensorflow/core/util/device_name_utils.cc @@ -0,0 +1,338 @@ +#include "tensorflow/core/util/device_name_utils.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +static bool IsAlpha(char c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); +} + +static bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); } + +// Returns true iff "in" is a valid job name. +static bool IsJobName(StringPiece in) { + if (in.empty()) return false; + if (!IsAlpha(in[0])) return false; + for (size_t i = 1; i < in.size(); ++i) { + if (!(IsAlphaNum(in[i]) || in[i] == '_')) return false; + } + return true; +} + +// Returns true and fills in "*job" iff "*in" starts with a job name. +static bool ConsumeJobName(StringPiece* in, string* job) { + if (in->empty()) return false; + if (!IsAlpha((*in)[0])) return false; + size_t i = 1; + for (; i < in->size(); ++i) { + const char c = (*in)[i]; + if (c == '/') break; + if (!(IsAlphaNum(c) || c == '_')) { + return false; + } + } + job->assign(in->data(), i); + in->remove_prefix(i); + return true; +} + +// Returns true and fills in "*device_type" iff "*in" starts with a device type +// name. +static bool ConsumeDeviceType(StringPiece* in, string* device_type) { + if (in->empty()) return false; + if (!IsAlpha((*in)[0])) return false; + size_t i = 1; + for (; i < in->size(); ++i) { + const char c = (*in)[i]; + if (c == '/' || c == ':') break; + if (!(IsAlphaNum(c) || c == '_')) { + return false; + } + } + device_type->assign(in->data(), i); + in->remove_prefix(i); + return true; +} + +// Returns true and fills in "*val" iff "*in" starts with a decimal +// number. 
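// For example, with *in == "3/task:2" this sets *val to 3 and advances *in
// past the digits, leaving "/task:2".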
+static bool ConsumeNumber(StringPiece* in, int* val) { + uint64 tmp; + if (str_util::ConsumeLeadingDigits(in, &tmp)) { + *val = tmp; + return true; + } else { + return false; + } +} + +/* static */ +string DeviceNameUtils::FullName(const string& job, int replica, int task, + const string& type, int id) { + CHECK(IsJobName(job)) << job; + CHECK_LE(0, replica); + CHECK_LE(0, task); + CHECK(!type.empty()); + CHECK_LE(0, id); + return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task, + "/device:", type, ":", id); +} + +bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) { + p->Clear(); + if (fullname == "/") { + return true; + } + StringPiece tmp; + while (!fullname.empty()) { + if (str_util::ConsumePrefix(&fullname, "/job:")) { + p->has_job = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_job && !ConsumeJobName(&fullname, &p->job)) { + return false; + } + } else if (str_util::ConsumePrefix(&fullname, "/replica:")) { + p->has_replica = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) { + return false; + } + } else if (str_util::ConsumePrefix(&fullname, "/task:")) { + p->has_task = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_task && !ConsumeNumber(&fullname, &p->task)) { + return false; + } + } else if (str_util::ConsumePrefix(&fullname, "/device:")) { + p->has_type = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) { + return false; + } + if (!str_util::ConsumePrefix(&fullname, ":")) { + p->has_id = false; + } else { + p->has_id = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_id && !ConsumeNumber(&fullname, &p->id)) { + return false; + } + } + + } else if (str_util::ConsumePrefix(&fullname, "/cpu:") || + str_util::ConsumePrefix(&fullname, "/CPU:")) { + p->has_type = true; + p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...' + p->has_id = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_id && !ConsumeNumber(&fullname, &p->id)) { + return false; + } + } else if (str_util::ConsumePrefix(&fullname, "/gpu:") || + str_util::ConsumePrefix(&fullname, "/GPU:")) { + p->has_type = true; + p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...' 
+ p->has_id = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_id && !ConsumeNumber(&fullname, &p->id)) { + return false; + } + } else { + return false; + } + } + return true; +} + +/* static */ +string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) { + string buf; + if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job); + if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica); + if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task); + if (pn.has_type) { + strings::StrAppend(&buf, "/", pn.type, ":"); + if (pn.has_id) { + strings::StrAppend(&buf, pn.id); + } else { + strings::StrAppend(&buf, "*"); + } + } + return buf; +} + +/* static */ +bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific, + const ParsedName& more_specific) { + if (less_specific.has_job && + (!more_specific.has_job || (less_specific.job != more_specific.job))) { + return false; + } + if (less_specific.has_replica && + (!more_specific.has_replica || + (less_specific.replica != more_specific.replica))) { + return false; + } + if (less_specific.has_task && + (!more_specific.has_task || (less_specific.task != more_specific.task))) { + return false; + } + if (less_specific.has_type && + (!more_specific.has_type || (less_specific.type != more_specific.type))) { + return false; + } + if (less_specific.has_id && + (!more_specific.has_id || (less_specific.id != more_specific.id))) { + return false; + } + return true; +} + +/* static */ +bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern, + const ParsedName& name) { + CHECK(name.has_job && name.has_replica && name.has_task && name.has_type && + name.has_id); + + if (pattern.has_job && (pattern.job != name.job)) return false; + if (pattern.has_replica && (pattern.replica != name.replica)) return false; + if (pattern.has_task && (pattern.task != name.task)) return false; + if (pattern.has_type && (pattern.type != name.type)) return false; + if (pattern.has_id && (pattern.id != name.id)) return false; + return true; +} + +/* static */ +Status DeviceNameUtils::MergeDevNames(ParsedName* target, + const ParsedName& other, + bool allow_soft_placement) { + if (other.has_job) { + if (target->has_job && target->job != other.job) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible jobs: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_job = other.has_job; + target->job = other.job; + } + } + + if (other.has_replica) { + if (target->has_replica && target->replica != other.replica) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible replicas: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_replica = other.has_replica; + target->replica = other.replica; + } + } + + if (other.has_task) { + if (target->has_task && target->task != other.task) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible tasks: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_task = other.has_task; + target->task = other.task; + } + } + + if (other.has_type) { + if (target->has_type && target->type != other.type) { + if (!allow_soft_placement) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible types: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_id = false; + target->has_type = false; + return Status::OK(); + } + } 
else { + target->has_type = other.has_type; + target->type = other.type; + } + } + + if (other.has_id) { + if (target->has_id && target->id != other.id) { + if (!allow_soft_placement) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible ids: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_id = false; + return Status::OK(); + } + } else { + target->has_id = other.has_id; + target->id = other.id; + } + } + + return Status::OK(); +} + +/* static */ +bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a, + const ParsedName& b) { + return (a.has_job && b.has_job && (a.job == b.job)) && + (a.has_replica && b.has_replica && (a.replica == b.replica)) && + (a.has_task && b.has_task && (a.task == b.task)); +} + +/* static */ +bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) { + ParsedName x; + ParsedName y; + return ParseFullName(src, &x) && ParseFullName(dst, &y) && + IsSameAddressSpace(x, y); +} + +/* static */ +string DeviceNameUtils::LocalName(StringPiece type, int id) { + return strings::StrCat(type, ":", id); +} + +/* static */ +string DeviceNameUtils::LocalName(StringPiece fullname) { + ParsedName x; + CHECK(ParseFullName(fullname, &x)) << fullname; + return LocalName(x.type, x.id); +} + +/* static */ +bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) { + ParsedName x; + if (!ConsumeDeviceType(&name, &p->type)) { + return false; + } + if (!str_util::ConsumePrefix(&name, ":")) { + return false; + } + if (!ConsumeNumber(&name, &p->id)) { + return false; + } + return name.empty(); +} + +/* static */ +bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task, + string* device) { + ParsedName pn; + if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) { + *task = strings::StrCat( + (pn.has_job ? strings::StrCat("/job:", pn.job) : ""), + (pn.has_replica ? strings::StrCat("/replica:", pn.replica) : ""), + (pn.has_task ? strings::StrCat("/task:", pn.task) : "")); + *device = strings::StrCat(pn.type, ":", pn.id); + return true; + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h new file mode 100644 index 0000000000..8b0a24ed0d --- /dev/null +++ b/tensorflow/core/util/device_name_utils.h @@ -0,0 +1,141 @@ +#ifndef TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_ +#define TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_ + +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// In TensorFlow a device name is a string of the following form: +// /job:/replica:/task:/device:: +// +// is a short identifier conforming to the regexp +// [a-zA-Z][_a-zA-Z]* +// is a supported device type (e.g. 'cpu' or 'gpu') +// , , are small non-negative integers and are +// densely allocated (except in tests). +// +// For some purposes, we also allow device patterns, which can specify +// some or none of the specific fields above, with missing components, +// or ":*" indicating "any value allowed for that component. 
+// +// For example: +// "/job:param_server" - Consider any devices in the "param_server" job +// "/device:cpu:*" - Consider any cpu devices in any job/task/replica +// "/job:*/replica:*/task:*/device:cpu:*" - Consider any cpu devices in any +// job/task/replica +// "/job:w/replica:0/task:0/device:gpu:*" - Consider any gpu devices in +// replica 0, task 0, of job "w" +class DeviceNameUtils { + public: + // Returns a fully qualified device name given the parameters. + static string FullName(const string& job, int replica, int task, + const string& type, int id); + + struct ParsedName { + void Clear() { + has_job = false; + has_replica = false; + has_task = false; + has_type = false; + has_id = false; + job.clear(); + replica = 0; + task = 0; + type.clear(); + id = 0; + } + + bool operator==(const ParsedName& other) const { + return (has_job ? (other.has_job && job == other.job) : !other.has_job) && + (has_replica ? (other.has_replica && replica == other.replica) + : !other.has_replica) && + (has_task ? (other.has_task && task == other.task) + : !other.has_task) && + (has_type ? (other.has_type && type == other.type) + : !other.has_type) && + (has_id ? (other.has_id && id == other.id) : !other.has_id); + } + + bool has_job = false; + string job; + bool has_replica = false; + int replica = 0; + bool has_task = false; + int task = 0; + bool has_type = false; + string type; + bool has_id = false; + int id = 0; + }; + // Parses "fullname" into "*parsed". Returns true iff succeeds. + static bool ParseFullName(StringPiece fullname, ParsedName* parsed); + + // Returns true if "name" specifies any non-trivial constraint on the device. + static bool HasSomeDetails(const ParsedName& name) { + return name.has_job || name.has_replica || name.has_task || name.has_type || + name.has_id; + } + + // Returns true if more_specific is a specification of + // less_specific, i.e. everywhere that less-specific has a + // non-wildcard component value, more_specific has the same value + // for that component. + static bool IsSpecification(const ParsedName& less_specific, + const ParsedName& more_specific); + + // Like IsSpecification, but the second argument "name" must have a + // non-wildcard value for all of its components. + static bool IsCompleteSpecification(const ParsedName& pattern, + const ParsedName& name); + + // True iff there exists any possible complete device name that is + // a specification of both "a" and "b". + static inline bool AreCompatibleDevNames(const ParsedName& a, + const ParsedName& b) { + return IsSpecification(a, b) || IsSpecification(b, a); + } + + // Merges the device specifications in "*target" and "other", and + // stores the result in "*target". Returns OK if "*target" and + // "other" are compatible, otherwise returns an error. + static Status MergeDevNames(ParsedName* target, const ParsedName& other) { + return MergeDevNames(target, other, false); + } + static Status MergeDevNames(ParsedName* target, const ParsedName& other, + bool allow_soft_placement); + + // Returns true iff devices identified by 'src' and 'dst' are in the + // same address space. + static bool IsSameAddressSpace(StringPiece src, StringPiece dst); + static bool IsSameAddressSpace(const ParsedName& src, const ParsedName& dst); + + // Returns the local device given its "type" and "id". + static string LocalName(StringPiece type, int id); + + // Returns a short local device name (cpu:0, gpu:1, etc) based on + // the given fullname. 
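  // For example, LocalName("/job:foo/replica:1/task:2/device:CPU:3")
  // returns "CPU:3".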
+ static string LocalName(StringPiece fullname); + + // If "name" is a valid local device name (cpu:0, gpu:1, etc.), + // fills in parsed.type and parsed.id accordingly. Returns true iff + // succeeds. + static bool ParseLocalName(StringPiece name, ParsedName* parsed); + + // Splits a fully-qualified device name into a task identifier and a + // relative device identifier. It first parses "name" using + // ParseFullName(), then assigns *task with everything except for + // the local device component, and assigns the relative device + // component into *device. This function will still return true if + // the task component is empty, but it requires the relative device + // component to be fully specified. + static bool SplitDeviceName(StringPiece name, string* task, string* device); + + static string ParsedNameToString(const ParsedName& pn); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_ diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc new file mode 100644 index 0000000000..14f30d6de5 --- /dev/null +++ b/tensorflow/core/util/device_name_utils_test.cc @@ -0,0 +1,369 @@ +#include "tensorflow/core/util/device_name_utils.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { + +TEST(DeviceNameUtilsTest, Basic) { + EXPECT_EQ(DeviceNameUtils::FullName("hello", 1, 2, "CPU", 3), + "/job:hello/replica:1/task:2/device:CPU:3"); + + { + DeviceNameUtils::ParsedName p; + EXPECT_FALSE(DeviceNameUtils::ParseFullName("foobar", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:123/replica:1/task:2/gpu:3", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:123/replica:1/task:2/gpu:", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseFullName( + "/job:123/replica:1/task:2/device:gpu:", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:foo/replica:-1/task:2/gpu:3", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:-2/gpu:3", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:2/bar:3", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseFullName( + "/job:foo/replica:1/task:2/gpu:3/extra", &p)); + EXPECT_TRUE( + DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:2/gpu:3", &p)); + EXPECT_TRUE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_TRUE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.job, "foo"); + EXPECT_EQ(p.replica, 1); + EXPECT_EQ(p.task, 2); + EXPECT_EQ(p.type, "GPU"); + EXPECT_EQ(p.id, 3); + } + { + // Allow _ in job names. + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName( + "/job:foo_bar/replica:1/task:2/gpu:3", &p)); + EXPECT_TRUE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_TRUE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.job, "foo_bar"); + EXPECT_EQ(p.replica, 1); + EXPECT_EQ(p.task, 2); + EXPECT_EQ(p.type, "GPU"); + EXPECT_EQ(p.id, 3); + } + { + // Allow _ in job names. 
+ DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName( + "/job:foo_bar/replica:1/task:2/device:GPU:3", &p)); + EXPECT_TRUE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_TRUE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.job, "foo_bar"); + EXPECT_EQ(p.replica, 1); + EXPECT_EQ(p.task, 2); + EXPECT_EQ(p.type, "GPU"); + EXPECT_EQ(p.id, 3); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName("/job:*/replica:4/gpu:*", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_FALSE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE( + DeviceNameUtils::ParseFullName("/job:*/replica:4/device:GPU:*", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_FALSE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE( + DeviceNameUtils::ParseFullName("/job:*/device:GPU/replica:4", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_FALSE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName( + "/job:*/replica:4/device:myspecialdevice:13", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "myspecialdevice"); + EXPECT_EQ(p.id, 13); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName("/", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_FALSE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_FALSE(p.has_type); + EXPECT_FALSE(p.has_id); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName("/job:*/replica:4/gpu:5", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + EXPECT_EQ(p.id, 5); + } + { // Same result if we reorder the components + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName("/gpu:*/job:*/replica:4", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_FALSE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + } + + EXPECT_TRUE(DeviceNameUtils::IsSameAddressSpace( + "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:1/task:2/gpu:4")); + EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace( + "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:1/task:3/gpu:4")); + EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace( + "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:10/task:2/gpu:4")); + EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace( + "/job:foo/replica:1/task:2/cpu:3", "/job:bar/replica:1/task:2/gpu:4")); + + EXPECT_EQ(DeviceNameUtils::LocalName("CPU", 1), "CPU:1"); + EXPECT_EQ(DeviceNameUtils::LocalName("GPU", 2), "GPU:2"); + EXPECT_EQ(DeviceNameUtils::LocalName("MySpecialDevice", 13), + "MySpecialDevice:13"); + + EXPECT_EQ( + DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/device:CPU:3"), + "CPU:3"); + + EXPECT_EQ(DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/cpu:3"), + "CPU:3"); + + 
EXPECT_EQ( + DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/device:abc:73"), + "abc:73"); + + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p)); + EXPECT_EQ(p.type, "CPU"); + EXPECT_EQ(p.id, 10); + EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p)); + } +} + +static bool IsCSHelper(StringPiece pattern, StringPiece actual) { + DeviceNameUtils::ParsedName p, a; + EXPECT_TRUE(DeviceNameUtils::ParseFullName(pattern, &p)); + EXPECT_TRUE(DeviceNameUtils::ParseFullName(actual, &a)); + return DeviceNameUtils::IsCompleteSpecification(p, a); +} + +TEST(DeviceNameUtilsTest, IsCompleteSpecification) { + EXPECT_TRUE(IsCSHelper("/job:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE( + IsCSHelper("/job:*/replica:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsCSHelper("/job:*/task:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsCSHelper("/job:*/replica:*/task:*", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE( + IsCSHelper("/job:*/replica:*/gpu:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_FALSE(IsCSHelper("/cpu:*", "/job:worker/replica:1/task:2/gpu:3")); + EXPECT_FALSE(IsCSHelper("/gpu:2", "/job:worker/replica:1/task:2/gpu:1")); + EXPECT_TRUE(IsCSHelper("/gpu:*", "/job:worker/replica:1/task:2/gpu:3")); +} + +static bool IsSpecHelper(StringPiece pattern, StringPiece actual) { + DeviceNameUtils::ParsedName p, a; + EXPECT_TRUE(DeviceNameUtils::ParseFullName(pattern, &p)); + EXPECT_TRUE(DeviceNameUtils::ParseFullName(actual, &a)); + return DeviceNameUtils::IsSpecification(p, a); +} + +TEST(DeviceNameUtilsTest, IsSpecification) { + EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1")); + EXPECT_TRUE(IsSpecHelper("/job:*", "/replica:1")); + EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work")); + EXPECT_TRUE( + IsSpecHelper("/job:*/replica:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/gpu:*", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/gpu:3", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/task:2", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:work/replica:*/task:2", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/task:*", "/job:*/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/task:2", "/job:*/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2/cpu:1")); + EXPECT_TRUE(IsSpecHelper("/cpu:0", "/cpu:0")); + EXPECT_TRUE(IsSpecHelper("/gpu:*", "/job:worker/replica:1/task:2/gpu:3")); + + EXPECT_FALSE(IsSpecHelper("/job:worker/replica:1/task:2/gpu:3", "/gpu:*")); + EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2")); + EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2/gpu:1")); + EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:worker/replica:1/task:2/gpu:3")); + EXPECT_FALSE(IsSpecHelper("/gpu:2", "/job:worker/replica:1/task:2/gpu:1")); + EXPECT_FALSE(IsSpecHelper("/job:work/replica:*/task:0", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_FALSE(IsSpecHelper("/job:work/replica:0/task:2", + "/job:work/replica:*/task:2/gpu:3")); +} + 
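// A minimal sketch (not part of the original test suite) of how the parsing
// and matching helpers exercised above are typically combined; the helper
// name MatchesRequest is illustrative only:
//
//   static bool MatchesRequest(StringPiece requested, StringPiece actual) {
//     DeviceNameUtils::ParsedName req, act;
//     return DeviceNameUtils::ParseFullName(requested, &req) &&
//            DeviceNameUtils::ParseFullName(actual, &act) &&
//            DeviceNameUtils::IsSpecification(req, act);
//   }
//
//   MatchesRequest("/job:work/gpu:*", "/job:work/replica:1/task:2/gpu:3");  // true
//   MatchesRequest("/gpu:2", "/job:work/replica:1/task:2/gpu:1");           // false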
+TEST(DeviceNameUtilsTest, SplitDeviceName) { + string task; + string device; + EXPECT_TRUE(DeviceNameUtils::SplitDeviceName( + "/job:foo/replica:1/task:2/cpu:1", &task, &device)); + EXPECT_EQ("/job:foo/replica:1/task:2", task); + EXPECT_EQ("CPU:1", device); + EXPECT_TRUE(DeviceNameUtils::SplitDeviceName( + "/job:foo/cpu:1/task:2/replica:1", &task, &device)); + EXPECT_EQ("/job:foo/replica:1/task:2", task); + EXPECT_EQ("CPU:1", device); + EXPECT_TRUE(DeviceNameUtils::SplitDeviceName("/gpu:3", &task, &device)); + EXPECT_EQ("", task); + EXPECT_EQ("GPU:3", device); + EXPECT_FALSE(DeviceNameUtils::SplitDeviceName("gpu:3", &task, &device)); + EXPECT_FALSE(DeviceNameUtils::SplitDeviceName("/job:foo/task:2/replica:1", + &task, &device)); + EXPECT_TRUE(DeviceNameUtils::SplitDeviceName("/device:myspecialdevice:3", + &task, &device)); + EXPECT_EQ("", task); + EXPECT_EQ("myspecialdevice:3", device); +} + +static DeviceNameUtils::ParsedName Name(const string& str) { + DeviceNameUtils::ParsedName ret; + CHECK(DeviceNameUtils::ParseFullName(str, &ret)) << "Invalid name: " << str; + return ret; +} + +static void MergeDevNamesHelperImpl(const string& name_a, const string& name_b, + const string& expected_merge_name, + bool allow_soft_placement) { + DeviceNameUtils::ParsedName target_a = Name(name_a); + EXPECT_OK(DeviceNameUtils::MergeDevNames(&target_a, Name(name_b), + allow_soft_placement)); + DeviceNameUtils::ParsedName target_b = Name(name_b); + EXPECT_OK(DeviceNameUtils::MergeDevNames(&target_b, Name(name_a), + allow_soft_placement)); + EXPECT_EQ(target_a, target_b); + EXPECT_EQ(target_a, Name(expected_merge_name)); + EXPECT_EQ(target_b, Name(expected_merge_name)); +} + +static void MergeDevNamesHelper(const string& name_a, const string& name_b, + const string& expected_merge_name) { + MergeDevNamesHelperImpl(name_a, name_b, expected_merge_name, false); +} + +static void MergeDevNamesHelperAllowSoftPlacement( + const string& name_a, const string& name_b, + const string& expected_merge_name) { + MergeDevNamesHelperImpl(name_a, name_b, expected_merge_name, true); +} + +static void MergeDevNamesError(const string& name_a, const string& name_b, + const string& expected_error_substr) { + DeviceNameUtils::ParsedName target_a = Name(name_a); + Status s = DeviceNameUtils::MergeDevNames(&target_a, Name(name_b)); + EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); + EXPECT_TRUE(StringPiece(s.error_message()).contains(expected_error_substr)) + << s; +} + +TEST(DeviceNameUtilsTest, MergeDevNames) { + DeviceNameUtils::ParsedName target; + + // Idempotence tests. + MergeDevNamesHelper("", "", ""); + MergeDevNamesHelper("/job:foo/replica:1/task:2/cpu:1", + "/job:foo/replica:1/task:2/cpu:1", + "/job:foo/replica:1/task:2/cpu:1"); + + // Merging with empty device has no effect. + MergeDevNamesHelper("", "/job:foo", "/job:foo"); + MergeDevNamesHelper("", "/replica:2", "/replica:2"); + MergeDevNamesHelper("", "/task:7", "/task:7"); + // MergeDevNamesHelper("", "/gpu:1", "/gpu:1"); + + // Combining disjoint names. + MergeDevNamesHelper("/job:foo", "/task:7", "/job:foo/task:7"); + MergeDevNamesHelper("/job:foo", "/gpu:1", "/job:foo/gpu:1"); + + // Combining overlapping names. + MergeDevNamesHelper("/job:foo/replica:0", "/replica:0/task:1", + "/job:foo/replica:0/task:1"); + + // Wildcard tests. + MergeDevNamesHelper("", "/gpu:*", "/gpu:*"); + MergeDevNamesHelper("/gpu:*", "/gpu:*", "/gpu:*"); + MergeDevNamesHelper("/gpu:1", "/gpu:*", "/gpu:1"); + + // Incompatible components. 
+  MergeDevNamesError("/job:foo", "/job:bar", "incompatible jobs");
+  MergeDevNamesError("/replica:0", "/replica:1", "incompatible replicas");
+  MergeDevNamesError("/task:0", "/task:1", "incompatible tasks");
+  MergeDevNamesError("/gpu:*", "/cpu:*", "incompatible types");
+  MergeDevNamesError("/gpu:0", "/gpu:1", "incompatible ids");
+}
+
+TEST(DeviceNameUtilsTest, MergeDevNamesAllowSoftPlacement) {
+  // Incompatible components with allow_soft_placement.
+  MergeDevNamesHelperAllowSoftPlacement("/gpu:*", "/cpu:1", "");
+  MergeDevNamesHelperAllowSoftPlacement("/cpu:*", "/gpu:1", "");
+  MergeDevNamesHelperAllowSoftPlacement("/gpu:1", "/gpu:2", "/gpu:*");
+}
+
+static void BM_ParseFullName(int iters) {
+  DeviceNameUtils::ParsedName p;
+  while (iters--) {
+    DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
+  }
+}
+BENCHMARK(BM_ParseFullName);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto
new file mode 100644
index 0000000000..5d67823ce7
--- /dev/null
+++ b/tensorflow/core/util/event.proto
@@ -0,0 +1,29 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/summary.proto";
+
+// Protocol buffer representing an event that happened during
+// the execution of a Brain model.
+message Event {
+  // Timestamp of the event.
+  double wall_time = 1;
+
+  // Global step of the event.
+  int64 step = 2;
+
+  oneof what {
+    // An event file was started, with the specified version.
+    // This is used to identify the contents of the record IO files
+    // easily. Current version is "tensorflow.Event:1". All versions
+    // start with "tensorflow.Event:".
+    string file_version = 3;
+    // A model was constructed.
+    GraphDef graph_def = 4;
+    // A summary was generated.
+    Summary summary = 5;
+  }
+}
diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc
new file mode 100644
index 0000000000..1b34a36577
--- /dev/null
+++ b/tensorflow/core/util/events_writer.cc
@@ -0,0 +1,144 @@
+#include "tensorflow/core/util/events_writer.h"
+
+#include <stddef.h>  // for NULL
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+
+EventsWriter::EventsWriter(const string& file_prefix)
+    // TODO(jeff,sanjay): Pass in env and use that here instead of Env::Default
+    : env_(Env::Default()),
+      file_prefix_(file_prefix),
+      num_outstanding_events_(0) {}
+
+bool EventsWriter::Init() {
+  if (recordio_writer_.get() != nullptr) {
+    CHECK(!filename_.empty());
+    if (FileHasDisappeared()) {
+      // Warn user of data loss and let .reset() below do basic cleanup.
+      if (num_outstanding_events_ > 0) {
+        LOG(WARNING) << "Re-initialization, attempting to open a new file, "
+                     << num_outstanding_events_ << " events will be lost.";
+      }
+    } else {
+      // No-op: File is present and writer is initialized.
+ return true; + } + } + + int64 time_in_seconds = env_->NowMicros() / 1000000; + + filename_ = strings::Printf( + "%s.out.tfevents.%010lld.%s", file_prefix_.c_str(), + static_cast(time_in_seconds), port::Hostname().c_str()); + port::AdjustFilenameForLogging(&filename_); + + WritableFile* file; + Status s = env_->NewWritableFile(filename_, &file); + if (!s.ok()) { + LOG(ERROR) << "Could not open events file: " << filename_ << ": " << s; + return false; + } + recordio_file_.reset(file); + recordio_writer_.reset(new io::RecordWriter(recordio_file_.get())); + if (recordio_writer_.get() == NULL) { + LOG(ERROR) << "Could not create record writer"; + return false; + } + num_outstanding_events_ = 0; + VLOG(1) << "Successfully opened events file: " << filename_; + { + // Write the first event with the current version, and flush + // right away so the file contents will be easily determined. + + Event event; + event.set_wall_time(time_in_seconds); + event.set_file_version(strings::StrCat(kVersionPrefix, kCurrentVersion)); + WriteEvent(event); + Flush(); + } + return true; +} + +string EventsWriter::FileName() { + if (filename_.empty()) { + Init(); + } + return filename_; +} + +void EventsWriter::WriteSerializedEvent(const string& event_str) { + if (recordio_writer_.get() == NULL) { + if (!Init()) { + LOG(ERROR) << "Write failed because file could not be opened."; + return; + } + } + num_outstanding_events_++; + recordio_writer_->WriteRecord(event_str); +} + +void EventsWriter::WriteEvent(const Event& event) { + string record; + event.AppendToString(&record); + WriteSerializedEvent(record); +} + +bool EventsWriter::Flush() { + if (num_outstanding_events_ == 0) return true; + CHECK(recordio_file_.get() != NULL) << "Unexpected NULL file"; + // The FileHasDisappeared() condition is necessary because + // recordio_writer_->Sync() can return true even if the underlying + // file has been deleted. EventWriter.FileDeletionBeforeWriting + // demonstrates this and will fail if the FileHasDisappeared() + // conditon is removed. + // Also, we deliberately attempt to Sync() before checking for a + // disappearing file, in case for some file system File::Exists() is + // false after File::Open() but before File::Sync(). + if (!recordio_file_->Flush().ok() || !recordio_file_->Sync().ok() || + FileHasDisappeared()) { + LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to " + << filename_; + return false; + } + VLOG(1) << "Wrote " << num_outstanding_events_ << " events to disk."; + num_outstanding_events_ = 0; + return true; +} + +bool EventsWriter::Close() { + bool return_value = Flush(); + if (recordio_file_.get() != NULL) { + Status s = recordio_file_->Close(); + if (!s.ok()) { + LOG(ERROR) << "Error when closing previous event file: " << filename_ + << ": " << s; + return_value = false; + } + recordio_writer_.reset(NULL); + recordio_file_.reset(NULL); + } + num_outstanding_events_ = 0; + return return_value; +} + +bool EventsWriter::FileHasDisappeared() { + if (env_->FileExists(filename_)) { + return false; + } else { + // This can happen even with non-null recordio_writer_ if some other + // process has removed the file. 
+ LOG(ERROR) << "The events file " << filename_ << " has disappeared."; + return true; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h new file mode 100644 index 0000000000..e6b94ad265 --- /dev/null +++ b/tensorflow/core/util/events_writer.h @@ -0,0 +1,77 @@ +#ifndef TENSORFLOW_UTIL_EVENTS_WRITER_H_ +#define TENSORFLOW_UTIL_EVENTS_WRITER_H_ + +#include +#include +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { + +class EventsWriter { + public: +#ifndef SWIG + // Prefix of version string present in the first entry of every event file. + static constexpr const char* kVersionPrefix = "brain.Event:"; + static constexpr const int kCurrentVersion = 1; +#endif + + // Events files typically have a name of the form + // '/some/file/path/my.file.out.events.[timestamp].[hostname]' + // To create and EventWriter, the user should provide file_prefix = + // '/some/file/path/my.file' + // The EventsWriter will append '.out.events.[timestamp].[hostname]' + // to the ultimate filename once Init() is called. + // Note that it is not recommended to simultaneously have two + // EventWriters writing to the same file_prefix. + explicit EventsWriter(const string& file_prefix); + ~EventsWriter() { Close(); } // Autoclose in destructor. + + // Sets the event file filename and opens file for writing. If not called by + // user, will be invoked automatically by a call to FileName() or Write*(). + // Returns false if the file could not be opened. Idempotent: if file exists + // and is open this is a no-op. If on the other hand the file was opened, + // but has since disappeared (e.g. deleted by another process), this will open + // a new file with a new timestamp in its filename. + bool Init(); + + // Returns the filename for the current events file: + // filename_ = [file_prefix_].out.events.[timestamp].[hostname] + string FileName(); + + // Append "event" to the file. The "tensorflow::" part is for swig happiness. + void WriteEvent(const tensorflow::Event& event); + + // Append "event_str", a serialized Event, to the file. + // Note that this function does NOT check that de-serializing event_str + // results in a valid Event proto. + void WriteSerializedEvent(const string& event_str); + + // EventWriter automatically flushes and closes on destruction, but + // these two methods are provided for users who want to write to disk sooner + // and/or check for success. + // Flush() pushes outstanding events to disk. Returns false if the + // events file could not be created, or if the file exists but could not + // be written too. + // Close() calls Flush() and then closes the current events file. + // Returns true only if both the flush and the closure were successful. + bool Flush(); + bool Close(); + + private: + bool FileHasDisappeared(); // True if event_file_path_ does not exist. 
+ + Env* env_; + const string file_prefix_; + string filename_; + std::unique_ptr recordio_file_; + std::unique_ptr recordio_writer_; + int num_outstanding_events_; + TF_DISALLOW_COPY_AND_ASSIGN(EventsWriter); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_EVENTS_WRITER_H_ diff --git a/tensorflow/core/util/events_writer_test.cc b/tensorflow/core/util/events_writer_test.cc new file mode 100644 index 0000000000..f6523ead92 --- /dev/null +++ b/tensorflow/core/util/events_writer_test.cc @@ -0,0 +1,198 @@ +#include "tensorflow/core/util/events_writer.h" + +#include +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { +namespace { + +// shorthand +Env* env() { return Env::Default(); } + +void WriteSimpleValue(EventsWriter* writer, double wall_time, int64 step, + const string& tag, float simple_value) { + Event event; + event.set_wall_time(wall_time); + event.set_step(step); + Summary::Value* summ_val = event.mutable_summary()->add_value(); + summ_val->set_tag(tag); + summ_val->set_simple_value(simple_value); + writer->WriteEvent(event); +} + +void WriteFile(EventsWriter* writer) { + WriteSimpleValue(writer, 1234, 34, "foo", 3.14159); + WriteSimpleValue(writer, 2345, 35, "bar", -42); +} + +static bool ReadEventProto(io::RecordReader* reader, uint64* offset, + Event* proto) { + string record; + Status s = reader->ReadRecord(offset, &record); + if (!s.ok()) { + return false; + } + return ParseProtoUnlimited(proto, record); +} + +void VerifyFile(const string& filename) { + CHECK(env()->FileExists(filename)); + RandomAccessFile* event_file; + TF_CHECK_OK(env()->NewRandomAccessFile(filename, &event_file)); + io::RecordReader* reader = new io::RecordReader(event_file); + + uint64 offset = 0; + + Event actual; + CHECK(ReadEventProto(reader, &offset, &actual)); + VLOG(1) << actual.ShortDebugString(); + // Wall time should be within 5s of now. + + double current_time = env()->NowMicros() / 1000000.0; + EXPECT_LT(fabs(actual.wall_time() - current_time), 5); + // Should have the current version number. 
+ EXPECT_EQ(actual.file_version(), + strings::StrCat(EventsWriter::kVersionPrefix, + EventsWriter::kCurrentVersion)); + + Event expected; + CHECK(ReadEventProto(reader, &offset, &actual)); + VLOG(1) << actual.ShortDebugString(); + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "wall_time: 1234 step: 34 " + "summary { value { tag: 'foo' simple_value: 3.14159 } }", + &expected)); + // TODO(keveman): Enable this check + // EXPECT_THAT(expected, EqualsProto(actual)); + + CHECK(ReadEventProto(reader, &offset, &actual)); + VLOG(1) << actual.ShortDebugString(); + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "wall_time: 2345 step: 35 " + "summary { value { tag: 'bar' simple_value: -42 } }", + &expected)); + // TODO(keveman): Enable this check + // EXPECT_THAT(expected, EqualsProto(actual)); + + TF_CHECK_OK(env()->DeleteFile(filename)); + + delete reader; + delete event_file; +} + +string GetDirName(const string& suffix) { + return io::JoinPath(testing::TmpDir(), suffix); +} + +TEST(EventWriter, WriteFlush) { + string file_prefix = GetDirName("/writeflush_test"); + EventsWriter writer(file_prefix); + WriteFile(&writer); + EXPECT_TRUE(writer.Flush()); + string filename = writer.FileName(); + VerifyFile(filename); +} + +TEST(EventWriter, WriteClose) { + string file_prefix = GetDirName("/writeclose_test"); + EventsWriter writer(file_prefix); + WriteFile(&writer); + EXPECT_TRUE(writer.Close()); + string filename = writer.FileName(); + VerifyFile(filename); +} + +TEST(EventWriter, WriteDelete) { + string file_prefix = GetDirName("/writedelete_test"); + EventsWriter* writer = new EventsWriter(file_prefix); + WriteFile(writer); + string filename = writer->FileName(); + delete writer; + VerifyFile(filename); +} + +TEST(EventWriter, FailFlush) { + string file_prefix = GetDirName("/failflush_test"); + EventsWriter writer(file_prefix); + string filename = writer.FileName(); + WriteFile(&writer); + EXPECT_TRUE(env()->FileExists(filename)); + env()->DeleteFile(filename); + EXPECT_FALSE(env()->FileExists(filename)); + EXPECT_FALSE(writer.Flush()); + EXPECT_FALSE(env()->FileExists(filename)); +} + +TEST(EventWriter, FailClose) { + string file_prefix = GetDirName("/failclose_test"); + EventsWriter writer(file_prefix); + string filename = writer.FileName(); + WriteFile(&writer); + EXPECT_TRUE(env()->FileExists(filename)); + env()->DeleteFile(filename); + EXPECT_FALSE(env()->FileExists(filename)); + EXPECT_FALSE(writer.Close()); + EXPECT_FALSE(env()->FileExists(filename)); +} + +TEST(EventWriter, InitWriteClose) { + string file_prefix = GetDirName("/initwriteclose_test"); + EventsWriter writer(file_prefix); + EXPECT_TRUE(writer.Init()); + string filename0 = writer.FileName(); + EXPECT_TRUE(env()->FileExists(filename0)); + WriteFile(&writer); + EXPECT_TRUE(writer.Close()); + string filename1 = writer.FileName(); + EXPECT_EQ(filename0, filename1); + VerifyFile(filename1); +} + +TEST(EventWriter, NameWriteClose) { + string file_prefix = GetDirName("/namewriteclose_test"); + EventsWriter writer(file_prefix); + string filename = writer.FileName(); + EXPECT_TRUE(env()->FileExists(filename)); + WriteFile(&writer); + EXPECT_TRUE(writer.Close()); + VerifyFile(filename); +} + +TEST(EventWriter, NameClose) { + string file_prefix = GetDirName("/nameclose_test"); + EventsWriter writer(file_prefix); + string filename = writer.FileName(); + EXPECT_TRUE(writer.Close()); + EXPECT_TRUE(env()->FileExists(filename)); + env()->DeleteFile(filename); +} + +TEST(EventWriter, FileDeletionBeforeWriting) { + string file_prefix = 
GetDirName("/fdbw_test"); + EventsWriter writer(file_prefix); + string filename0 = writer.FileName(); + EXPECT_TRUE(env()->FileExists(filename0)); + env()->SleepForMicroseconds( + 2000000); // To make sure timestamp part of filename will differ. + env()->DeleteFile(filename0); + EXPECT_TRUE(writer.Init()); // Init should reopen file. + WriteFile(&writer); + EXPECT_TRUE(writer.Flush()); + string filename1 = writer.FileName(); + EXPECT_NE(filename0, filename1); + VerifyFile(filename1); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/guarded_philox_random.cc b/tensorflow/core/util/guarded_philox_random.cc new file mode 100644 index 0000000000..4cf58b8979 --- /dev/null +++ b/tensorflow/core/util/guarded_philox_random.cc @@ -0,0 +1,39 @@ +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +Status GuardedPhiloxRandom::Init(OpKernelConstruction* context) { + // Grab seed Attrs. + int64 seed, seed2; + auto status = context->GetAttr("seed", &seed); + if (!status.ok()) return status; + status = context->GetAttr("seed2", &seed2); + if (!status.ok()) return status; + + // Initialize with the given seeds + Init(seed, seed2); + return Status::OK(); +} + +void GuardedPhiloxRandom::Init(int64 seed, int64 seed2) { + CHECK(!initialized_); + if (seed == 0 && seed2 == 0) { + // If both seeds are unspecified, use completely random seeds. + seed = random::New64(); + seed2 = random::New64(); + } + mutex_lock lock(mu_); + generator_ = random::PhiloxRandom(seed, seed2); + initialized_ = true; +} + +random::PhiloxRandom GuardedPhiloxRandom::ReserveSamples128(int64 samples) { + CHECK(initialized_); + mutex_lock lock(mu_); + auto local = generator_; + generator_.Skip(samples); + return local; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h new file mode 100644 index 0000000000..6e9cb9f99c --- /dev/null +++ b/tensorflow/core/util/guarded_philox_random.h @@ -0,0 +1,56 @@ +#ifndef TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ +#define TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// A thread safe wrapper around a Philox generator. Example usage: +// +// GuardedRandomPhilox generator; +// generator.Init(context); +// +// // In thread safe code +// const int samples = ...; +// auto local_generator = generator.ReserveSamples128(samples); +// for (int i = 0; i < samples; i++) +// Array sample = local_generator(); +// // Use sample +// } +// +class GuardedPhiloxRandom { + public: + // Must call Init to finish initialization + GuardedPhiloxRandom() : initialized_(false) {} + + // Initialize the generator from attributes "seed" and "seed2". + // If both seeds are unspecified, use random seeds. + // Must be called exactly once. + Status Init(OpKernelConstruction* context); + + // Initialize with given seeds. + void Init(int64 seed, int64 seed2); + + // Reserve a certain number of 128-bit samples. + // This function is thread safe. The returned generator is valid for the + // given number of samples, and can be used without a lock. 
+ random::PhiloxRandom ReserveSamples128(int64 samples); + + // Reserve a certain number of 32-bit samples + random::PhiloxRandom ReserveSamples32(int64 samples) { + return ReserveSamples128((samples + 3) / 4); + } + + private: + mutex mu_; + random::PhiloxRandom generator_ GUARDED_BY(mu_); + bool initialized_; + + TF_DISALLOW_COPY_AND_ASSIGN(GuardedPhiloxRandom); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc new file mode 100644 index 0000000000..24273e5ca4 --- /dev/null +++ b/tensorflow/core/util/padding.cc @@ -0,0 +1,24 @@ +#include "tensorflow/core/util/padding.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +Status GetNodeAttr(const NodeDef& node_def, const string& attr_name, + Padding* value) { + string str_value; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value)); + if (str_value == "SAME") { + *value = SAME; + } else if (str_value == "VALID") { + *value = VALID; + } else { + return errors::NotFound(str_value, " is not an allowed padding type"); + } + return Status::OK(); +} + +string GetPaddingAttrString() { return "padding: {'SAME', 'VALID'}"; } + +} // end namespace tensorflow diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h new file mode 100644 index 0000000000..66cd96abdb --- /dev/null +++ b/tensorflow/core/util/padding.h @@ -0,0 +1,37 @@ +#ifndef TENSORFLOW_UTIL_PADDING_H_ +#define TENSORFLOW_UTIL_PADDING_H_ + +// This file contains helper routines to deal with padding in various ops and +// kernels. + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Padding: the padding we apply to the input tensor along the rows and columns +// dimensions. This is usually used to make sure that the spatial dimensions do +// not shrink when we progress with convolutions. Two types of padding are +// supported: +// VALID: No padding is carried out. +// SAME: The pad value is computed so that the output will have the same +// dimensions as the input. +// The padded area is zero-filled. +enum Padding { + VALID = 1, // No padding. + SAME = 2, // Input and output layers have the same size. +}; + +// Return the string containing the list of valid padding types, that can be +// used as an Attr() in REGISTER_OP. +string GetPaddingAttrString(); + +// Specialization to parse an attribute directly into a Padding enum. +Status GetNodeAttr(const NodeDef& node_def, const string& attr_name, + Padding* value); + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_PADDING_H_ diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc new file mode 100644 index 0000000000..12eb076a4d --- /dev/null +++ b/tensorflow/core/util/port.cc @@ -0,0 +1,13 @@ +#include "tensorflow/core/util/port.h" + +namespace tensorflow { + +bool IsGoogleCudaEnabled() { +#if GOOGLE_CUDA + return true; +#else + return false; +#endif +} + +} // end namespace tensorflow diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h new file mode 100644 index 0000000000..8b9d033d63 --- /dev/null +++ b/tensorflow/core/util/port.h @@ -0,0 +1,11 @@ +#ifndef TENSORFLOW_UTIL_PORT_H_ +#define TENSORFLOW_UTIL_PORT_H_ + +namespace tensorflow { + +// Returns true if GOOGLE_CUDA is defined. 
+bool IsGoogleCudaEnabled(); + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_PORT_H_ diff --git a/tensorflow/core/util/saved_tensor_slice.proto b/tensorflow/core/util/saved_tensor_slice.proto new file mode 100644 index 0000000000..f6599d9669 --- /dev/null +++ b/tensorflow/core/util/saved_tensor_slice.proto @@ -0,0 +1,76 @@ +// Protocol buffers for saved tensor slices. It's used for the brain tensor +// ops checkpoints and the V3 checkpoints in dist_belief. + +// A checkpoint file is an sstable. The value for each record is a serialized +// SavedTensorSlices message (defined below). +// +// Each checkpoint file has a record with the empty key (""), which corresponds +// to a SavedTensorSlices message that contains a "meta", that serves as a +// table of contents on all the tensor slices saved in this file. Since the key +// is "", it's always the first record in each file. +// +// Each of the rest of the records in a checkpoint stores the raw data of a +// particular tensor slice, in SavedSlice format. The corresponding key is an +// ordered code that encodes the name of the tensor and the slice +// information. The name is also stored in the SaveSlice message for ease of +// debugging and manual examination. + +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/tensor_slice.proto"; +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/types.proto"; + +// Metadata describing the set of slices of the same tensor saved in a +// checkpoint file. +message SavedSliceMeta { + // Name of the tensor. + string name = 1; + + // Shape of the tensor + TensorShapeProto shape = 2; + + // Type of the tensor + DataType type = 3; + + // Explicit list of slices saved in the checkpoint file. + repeated TensorSliceProto slice = 4; +}; + +// Metadata describing the set of tensor slices saved in a checkpoint file. +// It is always stored at the beginning of each checkpoint file. +message SavedTensorSliceMeta { + // Each SavedSliceMeta describes the slices for one tensor. + repeated SavedSliceMeta tensor = 1; +}; + +// Saved tensor slice: it stores the name of the tensors, the slice, and the +// raw data. +message SavedSlice { + // Name of the tensor that this slice belongs to. This must be identical to + // the name used to encode the key for this record. + string name = 1; + + // Extent of the slice. Must have one entry for each of the dimension of the + // tensor that this slice belongs to. + TensorSliceProto slice = 2; + + // The raw data of the slice is stored as a TensorProto. Only raw data are + // stored (we don't fill in fields such as dtype or tensor_shape). + TensorProto data = 3; +}; + +// Each record in a v3 checkpoint file is a serialized SavedTensorSlices +// message. +message SavedTensorSlices { + // This is only present at the first item of each checkpoint file and serves + // as a table of contents, listing all the tensor slices saved in this file. + SavedTensorSliceMeta meta = 1; + + // This exists in all but the first item of each checkpoint file. 
+ SavedSlice data = 2; +}; diff --git a/tensorflow/core/util/saved_tensor_slice_util.cc b/tensorflow/core/util/saved_tensor_slice_util.cc new file mode 100644 index 0000000000..7a5903f07f --- /dev/null +++ b/tensorflow/core/util/saved_tensor_slice_util.cc @@ -0,0 +1,76 @@ +#include "tensorflow/core/util/saved_tensor_slice_util.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/ordered_code.h" + +namespace tensorflow { + +namespace checkpoint { + +const char kSavedTensorSlicesKey[] = ""; + +string EncodeTensorNameSlice(const string& name, const TensorSlice& slice) { + string buffer; + // All the tensor slice keys will start with a 0 + tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, 0); + tensorflow::strings::OrderedCode::WriteString(&buffer, name); + tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, slice.dims()); + for (int d = 0; d < slice.dims(); ++d) { + // A trivial extent (meaning we take EVERYTHING) will default to -1 for both + // start and end. These will be properly parsed. + tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer, + slice.start(d)); + tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer, + slice.length(d)); + } + return buffer; +} + +Status DecodeTensorNameSlice(const string& code, string* name, + tensorflow::TensorSlice* slice) { + StringPiece src(code); + uint64 x; + if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) { + return errors::Internal("Failed to parse the leading number: src = ", src); + } + if (x != 0) { + return errors::Internal( + "The leading number should always be 0 for any valid key: src = ", src); + } + if (!tensorflow::strings::OrderedCode::ReadString(&src, name)) { + return errors::Internal("Failed to parse the tensor name: src = ", src); + } + if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) { + return errors::Internal("Failed to parse the tensor rank: src = ", src); + } + if (x == 0) { + return errors::Internal("Expecting positive rank of the tensor, got ", x, + ", src = ", src); + } + if (x >= kint32max) { + return errors::Internal("Too many elements ", x); + } + slice->SetFullSlice(x); + for (int d = 0; d < static_cast(x); ++d) { + // We expected 2x integers + int64 start, length; + if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src, + &start)) { + return errors::Internal("Failed to parse start: src = ", src); + } + if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src, + &length)) { + return errors::Internal("Failed to parse length: src = ", src); + } + if (length >= 0) { + // a non-trivial extent + slice->set_start(d, start); + slice->set_length(d, length); + } + } + return Status::OK(); +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h new file mode 100644 index 0000000000..6206cd8538 --- /dev/null +++ b/tensorflow/core/util/saved_tensor_slice_util.h @@ -0,0 +1,110 @@ +// Utilities for saving/restoring tensor slice checkpoints. 
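+//
+// Example (an illustrative sketch only, mirroring the round trip exercised in
+// saved_tensor_slice_util_test.cc; the tensor name "my_tensor" and the slice
+// spec are made up for illustration):
+//
+//   TensorSlice slice = TensorSlice::ParseOrDie("-:-:1,3:4,5");
+//   string key = EncodeTensorNameSlice("my_tensor", slice);
+//   string name;
+//   TensorSlice decoded;
+//   TF_CHECK_OK(DecodeTensorNameSlice(key, &name, &decoded));
+//   CHECK_EQ(name, "my_tensor");
+//   CHECK_EQ(decoded.DebugString(), "-:-:1,3:4,5");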
+ +#ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ +#define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ + +#include // for string +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/status.h" // for Status + +namespace tensorflow { + +namespace checkpoint { + +// The key for the metadata in the tensor slice checkpoint files. It is "" so +// that the metadata is always at the beginning of a checkpoint file. +extern const char kSavedTensorSlicesKey[]; + +// Encode a tensor name + a tensor slice into an ordered code and outputs it as +// a string. +// The format is +// <0> +// +// +// +// +// ... + +string EncodeTensorNameSlice(const string& name, + const tensorflow::TensorSlice& slice); + +// Parse out the name and the slice from string encoded as an ordered code. +Status DecodeTensorNameSlice(const string& code, string* name, + tensorflow::TensorSlice* slice); + +template +struct SaveTypeTraits; + +template +const typename SaveTypeTraits::SavedType* TensorProtoData( + const TensorProto& t); + +template +protobuf::RepeatedField::SavedType>* +MutableTensorProtoData(TensorProto* t); + +template +void Fill(T* data, size_t n, TensorProto* t); + +#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \ + template <> \ + struct SaveTypeTraits { \ + static constexpr bool supported = true; \ + typedef FTYPE SavedType; \ + }; \ + template <> \ + inline const FTYPE* TensorProtoData(const TensorProto& t) { \ + static_assert(SaveTypeTraits::supported, \ + "Specified type " #TYPE " not supported for Restore"); \ + return reinterpret_cast(t.FIELD##_val().data()); \ + } \ + template <> \ + inline protobuf::RepeatedField* MutableTensorProtoData( \ + TensorProto * t) { \ + static_assert(SaveTypeTraits::supported, \ + "Specified type " #TYPE " not supported for Save"); \ + return reinterpret_cast*>( \ + t->mutable_##FIELD##_val()); \ + } \ + template <> \ + inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ + typename protobuf::RepeatedField copy(data, data + n); \ + t->mutable_##FIELD##_val()->Swap(©); \ + } + +TENSOR_PROTO_EXTRACT_TYPE(float, float, float); +TENSOR_PROTO_EXTRACT_TYPE(double, double, double); +TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(int64, int64, int64); +TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); + +#undef TENSOR_PROTO_EXTRACT_TYPE + +template <> +struct SaveTypeTraits : SaveTypeTraits {}; + +template <> +inline const int32* TensorProtoData(const TensorProto& t) { + static_assert(SaveTypeTraits::supported, + "Specified type qint32 not supported for Restore"); + return reinterpret_cast(t.int_val().data()); +} + +inline void Fill(const qint32* data, size_t n, TensorProto* t) { + const int32* p = reinterpret_cast(data); + typename protobuf::RepeatedField copy(p, p + n); + t->mutable_int_val()->Swap(©); +} + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ diff --git a/tensorflow/core/util/saved_tensor_slice_util_test.cc b/tensorflow/core/util/saved_tensor_slice_util_test.cc new file mode 100644 index 0000000000..2c34c903db --- /dev/null +++ b/tensorflow/core/util/saved_tensor_slice_util_test.cc @@ -0,0 +1,32 @@ +#include 
"tensorflow/core/util/saved_tensor_slice_util.h" + +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +// Testing serialization of tensor name and tensor slice in the ordered code +// format. +TEST(TensorShapeUtilTest, TensorNameSliceToOrderedCode) { + { + TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5"); + string buffer = EncodeTensorNameSlice("foo", s); + string name; + s.Clear(); + TF_CHECK_OK(DecodeTensorNameSlice(buffer, &name, &s)); + EXPECT_EQ("foo", name); + EXPECT_EQ("-:-:1,3:4,5", s.DebugString()); + } +} + +} // namespace + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/sparse/README.md b/tensorflow/core/util/sparse/README.md new file mode 100644 index 0000000000..7b0799eb0e --- /dev/null +++ b/tensorflow/core/util/sparse/README.md @@ -0,0 +1,222 @@ +SparseTensor +============ + +Sparse Tensors are stored as two dense tensors and a shape: + +* `indices`: a `brain::Tensor` storing a matrix, typically `int64` +* `values`: a `brain::Tensor` storing a vector with values of type T. +* `shape`: a `TensorShape` storing the bounds of the underlying tensor +* `order`: (optional) a `gtl::InlinedVector` with the dimensions + along which the indices are ordered. + +Let + + ix = indices.matrix() + vals = values.vec() + +The shape of `ix` is `N x NDIMS`, and each row corresponds to the +index of a single element of the sparse tensor. + +The length of `vals` must be `N`, and `vals(i)` corresponds to the +value with index `ix(i,:)`. + +Shape must be a `TensorShape` with `dims() == NDIMS`. +The shape is the full shape of the dense tensor these indices +represent. + +To be specific, the representation (pseudocode) is: + + tensor[ix[i,:]] == vals[i] for i = 0, ..., N-1 + +Ordering +-------- + +Indices need not be provided in order. For example, the following +index matrix is ordered according to dimension order `{0, 1, 2}`. + + [0 0 1] + [0 1 1] + [2 0 2] + +However, you can provide an unordered version: + + [2 0 2] + [0 0 1] + [0 1 1] + +If the SparseTensor is constructed without a provided order, then a +the default order is `{-1, ..., -1}`. Certain operations will fail or crash +when the order is not provided. + +Resorting the SparseTensor in-place (which resorts the underlying index and +values tensors in-place) will update the order. The cost of reordering the +matrix is `O(N*log(N))`, and requires `O(N)` additional temporary space to store +a reordering index. If the default order is not specified and reordering is not +performed, the following will happen: + +* `group()` will **raise an assertion failure** +* `IndicesValid()` will **raise an assertion failure** + +To update the internal index ordering after construction, call +`Reorder()` via, e.g., `Reorder({0,1,2})`. +After this step, all the above methods should work correctly. 
+ +The method `IndicesValid()` checks to make sure: + +* `0 <= ix(i, d) < shape.dim_size(d)` +* indices do not repeat +* indices are in order + +Iterating +--------- + +### group({grouping dims}) + +* provides an iterator that groups entries according to + dimensions you care about +* may require a sort if your data isn't presorted in a way that's + compatible with grouping_dims +* for each group, returns the group index (values of the group + dims for this iteration), the subset of indices in this group, + and the subset of values in this group. these are lazy outputs + so to read them individually, copy them as per the example + below. + +#### **NOTE** +`group({dim0, ..., dimk})` will **raise an assertion failure** if the +order of the SparseTensor does not match the dimensions you wish to group by. +You must either have your indices in the correct order and construct the +SparseTensor with + + order = {dim0, ..., dimk, ...} + +or call + + Reorder({dim0, .., dimk, ...}) + +to sort the SparseTensor before grouping. + +Example of grouping: + + Tensor indices(DT_INT64, TensorShape({N, NDIMS}); + Tensor values(DT_STRING, TensorShape({N}); + TensorShape shape({dim0,...}); + SparseTensor sp(indices, vals, shape); + sp.Reorder({1, 2, 0, 3, ...}); // Must provide NDIMS dims. + // group according to dims 1 and 2 + for (const auto& g : sp.group({1, 2})) { + cout << "vals of ix[:, 1,2] for this group: " + << g.group()[0] << ", " << g.group()[1]; + cout << "full indices of group:\n" << g.indices(); + cout << "values of group:\n" << g.values(); + + TTypes::UnalignedMatrix g_ix = g.indices(); + TTypes::UnalignedVec g_v = g.values(); + ASSERT(g_ix.dimension(0) == g_v.size()); // number of elements match. + } + + +ToDense +-------- + +Converts sparse tensor to dense. You must provide a pointer to the +dense tensor (preallocated). `ToDense()` will optionally +preinitialize the tensor with zeros. + +Shape checking is performed, as is boundary checking. + + Tensor indices(DT_INT64, TensorShape({N, NDIMS}); + Tensor values(DT_STRING, TensorShape({N}); + TensorShape shape({dim0,...}); + SparseTensor sp(indices, vals, shape); + ASSERT(sp.IndicesValid()); // checks ordering & index bounds. + + Tensor dense(DT_STRING, shape); + // initialize other indices to zero. copy. + ASSERT(sp.ToDense(&dense, true)); + + +Concat +-------- + +Concatenates multiple SparseTensors and returns a new SparseTensor. +This concatenation is with respect to the "dense" versions of these +SparseTensors. Concatenation is performed along dimension order[0] +of all tensors. As a result, shape[order[0]] may differ across +the inputs, but shape[d] for d != order[0] must match across all inputs. + +We call order[0] the **primary dimension**. + +**Prerequisites** + +* The inputs' ranks must all match. +* The inputs' order[0] must all match. +* The inputs' shapes must all match except for dimension order[0]. +* The inputs' values must all be of the same type. + +If any of these are false, concat will die with an assertion failure. + +Example: +Concatenate two sparse matrices along columns. 
+ +Matrix 1: + + [0 0 1] + [2 0 0] + [3 0 4] + +Matrix 2: + + [0 0 0 0 0] + [0 1 0 0 0] + [2 0 0 1 0] + +Concatenated Matrix: + + [0 0 1 0 0 0 0 0] + [2 0 0 0 1 0 0 0] + [3 0 4 2 0 0 1 0] + +Expected input shapes, orders, and `nnz()`: + + shape_1 = TensorShape({3, 3}) + shape_2 = TensorShape({3, 8}) + order_1 = {1, 0} // primary order is 1, columns + order_2 = {1, 0} // primary order is 1, must match + nnz_1 = 4 + nnz_2 = 3 + +Output shapes and orders: + + conc_shape = TensorShape({3, 11}) // primary dim increased, others same + conc_order = {1, 0} // Orders match along all inputs + conc_nnz = 7 // Sum of nonzeros of inputs + +Coding Example: + + Tensor ix1(DT_INT64, TensorShape({N1, 3}); + Tensor vals1(DT_STRING, TensorShape({N1, 3}); + Tensor ix2(DT_INT64, TensorShape({N2, 3}); + Tensor vals2(DT_STRING, TensorShape({N2, 3}); + Tensor ix3(DT_INT64, TensorShape({N3, 3}); + Tensor vals3(DT_STRING, TensorShape({N3, 3}); + + SparseTensor st1(ix1, vals1, TensorShape({10, 20, 5}), {1, 0, 2}); + SparseTensor st2(ix2, vals2, TensorShape({10, 10, 5}), {1, 0, 2}); + // For kicks, st3 indices are out of order, but order[0] matches so we + // can still concatenate along this dimension. + SparseTensor st3(ix3, vals3, TensorShape({10, 30, 5}), {1, 2, 0}); + + SparseTensor conc = SparseTensor::Concat({st1, st2, st3}); + Tensor ix_conc = conc.indices(); + Tensor vals_conc = conc.values(); + EXPECT_EQ(conc.nnz(), st1.nnz() + st2.nnz() + st3.nnz()); + EXPECT_EQ(conc.Shape(), TensorShape({10, 60, 5})); + EXPECT_EQ(conc.Order(), {-1, -1, -1}); + + // Reorder st3 so all input tensors have the exact same orders. + st3.Reorder({1, 0, 2}); + SparseTensor conc2 = SparseTensor::Concat({st1, st2, st3}); + EXPECT_EQ(conc2.Order(), {1, 0, 2}); + // All indices' orders matched, so output is in order. + EXPECT_TRUE(conc2.IndicesValid()); diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h new file mode 100644 index 0000000000..57473867cf --- /dev/null +++ b/tensorflow/core/util/sparse/dim_comparator.h @@ -0,0 +1,60 @@ +#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ +#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { + +///////////////// +// DimComparator +///////////////// +// +// Helper class, mainly used by the IndexSortOrder. This comparator +// can be passed to e.g. std::sort, or any other sorter, to sort two +// rows of an index matrix according to the dimension(s) of interest. +// The dimensions to sort by are passed to the constructor as "order". +// +// Example: if given index matrix IX, two rows ai and bi, and order = {2,1}. +// operator() compares +// IX(ai,2) < IX(bi,2). +// If IX(ai,2) == IX(bi,2), it compares +// IX(ai,1) < IX(bi,1). +// +// This can be used to sort a vector of row indices into IX according to +// the values in IX in particular columns (dimensions) of interest. 
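+//
+// Example (an illustrative sketch, mirroring the DimComparatorSorts unit test
+// in sparse_tensor_test.cc; "ix_data" is a made-up pointer to N * NDIMS
+// int64 index entries):
+//
+//   TTypes<int64>::Matrix ix_mat(ix_data, N, NDIMS);
+//   std::vector<int64> rows(N);
+//   for (int64 n = 0; n < N; ++n) rows[n] = n;
+//   std::vector<int64> order{0, 1, 2};            // sort by dims 0, then 1, then 2
+//   DimComparator sorter(ix_mat, order, NDIMS);
+//   std::sort(rows.begin(), rows.end(), sorter);
+//   // "rows" now permutes the index rows into lexicographic order over
+//   // dimensions {0, 1, 2}.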
+class DimComparator { + public: + typedef typename gtl::ArraySlice VarDimArray; + + inline DimComparator(const TTypes::Matrix& ix, + const VarDimArray& order, int dims) + : ix_(ix), order_(order), dims_(dims) { + CHECK_GT(order.size(), 0) << "Must order using at least one index"; + CHECK_LE(order.size(), dims_) << "Can only sort up to dims"; + for (size_t d = 0; d < order.size(); ++d) { + CHECK_GE(order[d], 0); + CHECK_LT(order[d], dims); + } + } + + inline bool operator()(const int64 i, const int64 j) const { + for (int di = 0; di < dims_; ++di) { + const int64 d = order_[di]; + if (ix_(i, d) < ix_(j, d)) return true; + if (ix_(i, d) > ix_(j, d)) return false; + } + return false; + } + + const TTypes::Matrix ix_; + const VarDimArray order_; + const int dims_; +}; + +} // namespace sparse +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc new file mode 100644 index 0000000000..e153bcdbb4 --- /dev/null +++ b/tensorflow/core/util/sparse/group_iterator.cc @@ -0,0 +1,49 @@ +#include "tensorflow/core/util/sparse/group_iterator.h" + +namespace tensorflow { +namespace sparse { + +void GroupIterable::IteratorStep::UpdateEndOfGroup() { + ++next_loc_; + int64 N = iter_->ix_.dim_size(0); + auto ix_t = iter_->ix_.template matrix(); + while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) { + ++next_loc_; + } +} + +bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const { + CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators"; + return (rhs.loc_ != loc_); +} + +GroupIterable::IteratorStep& GroupIterable::IteratorStep:: +operator++() { // prefix ++ + loc_ = next_loc_; + UpdateEndOfGroup(); + return *this; +} + +GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++( + int) { // postfix ++ + IteratorStep lhs(*this); + ++(*this); + return lhs; +} + +std::vector Group::group() const { + std::vector g; + auto ix_t = iter_->ix_.template matrix(); + for (const int d : iter_->group_dims_) { + g.push_back(ix_t(loc_, d)); + } + return g; +} + +TTypes::UnalignedConstMatrix Group::indices() const { + return TTypes::UnalignedConstMatrix( + &(iter_->ix_.matrix()(loc_, 0)), next_loc_ - loc_, iter_->dims_); +} + +} // namespace sparse +} // namespace tensorflow diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h new file mode 100644 index 0000000000..8423d54f27 --- /dev/null +++ b/tensorflow/core/util/sparse/group_iterator.h @@ -0,0 +1,120 @@ +#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ +#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { + +class GroupIterable; // Predeclare GroupIterable for Group. + +// This class is returned when dereferencing a GroupIterable iterator. +// It provides the methods group(), indices(), and values(), which +// provide access into the underlying SparseTensor. 
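+//
+// Example (an illustrative sketch; "st" is assumed to be a SparseTensor of
+// int32 values that has already been Reorder()ed so that dim 0 is the primary
+// sort dimension):
+//
+//   for (const auto& g : st.group({0})) {
+//     const int64 group_id = g.group()[0];    // Value of dim 0 for this group.
+//     const auto indices = g.indices();       // TTypes<int64>::UnalignedConstMatrix
+//     const auto values = g.values<int32>();  // TTypes<int32>::UnalignedVec
+//     CHECK_EQ(indices.dimension(0), values.size());  // One index row per value.
+//   }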
+class Group { + public: + Group(GroupIterable* iter, int64 loc, int64 next_loc) + : iter_(iter), loc_(loc), next_loc_(next_loc) {} + + std::vector group() const; + TTypes::UnalignedConstMatrix indices() const; + template + typename TTypes::UnalignedVec values() const; + + private: + GroupIterable* iter_; + int64 loc_; + int64 next_loc_; +}; + +///////////////// +// GroupIterable +///////////////// +// +// Returned when calling sparse_tensor.group({dim0, dim1, ...}). +// +// Please note: the sparse_tensor should already be ordered according +// to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups. +// +// Allows grouping and iteration of the SparseTensor according to the +// subset of dimensions provided to the group call. +// +// The actual grouping dimensions are stored in the +// internal vector group_dims_. Iterators inside the iterable provide +// the three methods: +// +// * group(): returns a vector with the current group dimension values. +// * indices(): a map of index, providing the indices in +// this group. +// * values(): a map of values, providing the values in +// this group. +// +// To iterate across GroupIterable, see examples in README.md. +// + +// Forward declaration of SparseTensor +class GroupIterable { + public: + typedef gtl::ArraySlice VarDimArray; + + GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims) + : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {} + + class IteratorStep; + + IteratorStep begin() { return IteratorStep(this, 0); } + IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); } + + template + inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const { + bool matches = true; + for (int d : group_dims_) { + if (ix(loc_a, d) != ix(loc_b, d)) { + matches = false; + } + } + return matches; + } + + class IteratorStep { + public: + IteratorStep(GroupIterable* iter, int64 loc) + : iter_(iter), loc_(loc), next_loc_(loc_) { + UpdateEndOfGroup(); + } + + void UpdateEndOfGroup(); + bool operator!=(const IteratorStep& rhs) const; + IteratorStep& operator++(); // prefix ++ + IteratorStep operator++(int); // postfix ++ + Group operator*() const { return Group(iter_, loc_, next_loc_); } + + private: + GroupIterable* iter_; + int64 loc_; + int64 next_loc_; + }; + + private: + friend class Group; + Tensor ix_; + Tensor vals_; + const int dims_; + const VarDimArray group_dims_; +}; + +// Implementation of Group::values() +template +typename TTypes::UnalignedVec Group::values() const { + return typename TTypes::UnalignedVec(&(iter_->vals_.vec()(loc_)), + next_loc_ - loc_); +} + +} // namespace sparse +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h new file mode 100644 index 0000000000..dcb75e7f54 --- /dev/null +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -0,0 +1,353 @@ +#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ +#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ + +#include + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/sparse/dim_comparator.h" +#include 
"tensorflow/core/util/sparse/group_iterator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { + +class SparseTensor { + public: + typedef typename gtl::ArraySlice VarDimArray; + + SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape) + : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {} + + SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape, + const VarDimArray& order) + : ix_(ix), + vals_(vals), + shape_(shape), + order_(order.begin(), order.end()), + dims_(GetDimsFromIx(ix)) { + CHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: " + << ix.dtype(); + CHECK(TensorShapeUtils::IsMatrix(ix.shape())) + << "indices must be a matrix, but got: " << ix.shape().DebugString(); + CHECK(TensorShapeUtils::IsVector(vals.shape())) + << "vals must be a vec, but got: " << vals.shape().DebugString(); + CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) + << "indices and values rows (indexing dimension) must match."; + } + + std::size_t num_entries() const { return ix_.dim_size(0); } + + const Tensor& indices() const { return ix_; } + + const Tensor& values() const { return vals_; } + + DataType dtype() const { return vals_.dtype(); } + + bool IndicesValid() const { + const auto ix_t = ix_.matrix(); + for (int64 ord : order_) { + CHECK_GE(ord, 0) << "Order was not provided. Provide an order at " + "construction time or run ReorderInPlace"; + } + + for (std::size_t n = 0; n < num_entries(); ++n) { + if (!IndexValid(ix_t, n)) return false; + } + + return true; + } + + // Returns the tensor shape (the dimensions of the "densified" + // tensor this tensor represents). + const TensorShape shape() const { return shape_; } + + const VarDimArray order() const { return order_; } + + // Resorts the indices and values according to the dimensions in order. + template + void Reorder(const VarDimArray& order); + + // Returns a group iterable that can be used for clumping indices + // and values according to the group indices of interest. + // + // Precondition: order()[0..group_ix.size()] == group_ix. + // + // See the README.md in this directory for more usage information. + GroupIterable group(const VarDimArray& group_ix) { + CHECK_LE(group_ix.size(), dims_); + for (std::size_t di = 0; di < group_ix.size(); ++di) { + CHECK_GE(group_ix[di], 0) << "Group dimension out of range"; + CHECK_LT(group_ix[di], dims_) << "Group dimension out of range"; + CHECK_EQ(group_ix[di], order_[di]) + << "Group dimension does not match sorted order"; + } + return GroupIterable(ix_, vals_, dims_, group_ix); + } + + // Stores the sparse indices into the dense tensor out. + // Preconditions: + // out->shape().dims() == shape().dims() + // out->shape().dim_size(d) >= shape(d) for all d + // + // Returns true on success. False on failure (mismatched dimensions + // or out-of-bounds indices). + // + // If initialize==True, ToDense first overwrites all coefficients in out to 0. + // + template + bool ToDense(Tensor* out, bool initialize = true); + + // Concat() will concatenate all the tensors according to their first order + // dimension. All tensors must have identical shape except for + // the first order dimension. All tensors orders' first dimension + // must match. + // + // If all of the tensors have identical ordering, then the output + // will have this ordering. Otherwise the output is set as not + // having any order and a Reorder() should be called on it before + // performing any subsequent operations. 
+ template + static SparseTensor Concat(const gtl::ArraySlice& tensors); + + private: + static int GetDimsFromIx(const Tensor& ix) { + CHECK(TensorShapeUtils::IsMatrix(ix.shape())); + return ix.dim_size(1); + } + + static gtl::InlinedVector UndefinedOrder(const TensorShape& shape) { + return gtl::InlinedVector(shape.dims(), -1); + } + + // Helper for IndicesValid() + inline bool IndexValid(const TTypes::ConstMatrix& ix_t, + int64 n) const { + bool different = false; + bool bad_order = false; + bool valid = true; + if (n == 0) { + for (int di = 0; di < dims_; ++di) { + if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di)) + valid = false; + } + different = true; + } else { + for (int di = 0; di < dims_; ++di) { + if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di)) + valid = false; + int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]); + if (diff > 0) different = true; + if (!different && diff < 0) bad_order = true; + } + } + if (!valid) return false; // Out of bounds + if (!different) return false; // The past two indices are identical... + if (bad_order) return false; // Decreasing in order. + return true; + } + + // Helper for ToDense() + template + bool ValidateAndInitializeToDense(Tensor* out, bool initialize); + + Tensor ix_; + Tensor vals_; + TensorShape shape_; + gtl::InlinedVector order_; + const int dims_; +}; + +// This operation updates the indices and values Tensor rows, so it is +// an in-place algorithm. It requires O(N log N) time and O(N) +// temporary space. +template +void SparseTensor::Reorder(const VarDimArray& order) { + CHECK_EQ(DataTypeToEnum::v(), dtype()) + << "Reorder requested with the wrong datatype"; + CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank"; + auto ix_t = ix_.matrix(); + auto vals_t = vals_.vec(); + + DimComparator sorter(ix_t, order, dims_); + + std::vector reorder(num_entries()); + std::iota(reorder.begin(), reorder.end(), 0); + + // Sort to get order of indices + std::sort(reorder.begin(), reorder.end(), sorter); + + // We have a forward reordering, but what we'll need is a + // permutation (the inverse). This can be calculated with O(1) + // additional + // and O(n) time (INVPERM) but we just do the simple thing here. + std::vector permutation(reorder.size()); + for (std::size_t n = 0; n < reorder.size(); ++n) { + permutation[reorder[n]] = n; + } + + // Update indices & values by converting the permutations to + // a product of transpositions. Iterate over the cycles in the + // permutation, and convert each of those into a product of + // transpositions (swaps): + // https://en.wikipedia.org/wiki/Cyclic_permutation + // This is N swaps, 2*N comparisons. + for (std::size_t n = 0; n + 1 < permutation.size(); ++n) { + while (n != permutation[n]) { + std::size_t r = permutation[n]; + std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0))); + std::swap(vals_t(n), vals_t(r)); + std::swap(permutation[n], permutation[r]); + } + } + + order_ = gtl::InlinedVector(order.begin(), order.end()); +} + +template +bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) { + CHECK_EQ(DataTypeToEnum::v(), dtype()) + << "ToDense requested with the wrong datatype"; + + CHECK_EQ(out->shape().dims(), dims_) + << "Incompatible dimensions between SparseTensor and output"; + + CHECK_EQ(out->dtype(), DataTypeToEnum::v()) + << "Output must be type: " << DataTypeToEnum::v() + << " but got: " << out->dtype(); + + // Make sure the dense output is the same rank and has room + // to hold the SparseTensor. 
+ const auto& out_shape = out->shape(); + if (shape_.dims() != out_shape.dims()) return false; + for (int d = 0; d < shape_.dims(); ++d) { + if (shape_.dim_size(d) > out_shape.dim_size(d)) return false; + } + + if (initialize) { + auto out_t = out->flat(); + out_t.setConstant(T()); + } + + return true; +} + +template +bool SparseTensor::ToDense(Tensor* out, bool initialize) { + if (!ValidateAndInitializeToDense(out, initialize)) return false; + + auto out_t = out->flat(); + auto ix_t = ix_.matrix(); + auto vals_t = vals_.vec(); + + std::vector strides(dims_); + const auto& out_shape = out->shape(); + strides[dims_ - 1] = 1; + for (int d = dims_ - 2; d >= 0; --d) { + strides[d] = strides[d + 1] * out_shape.dim_size(d + 1); + } + + for (std::size_t n = 0; n < vals_t.dimension(0); ++n) { + bool invalid_dims = false; + int64 ix = 0; + for (int d = 0; d < dims_; ++d) { + const int64 ix_n_d = ix_t(n, d); + if (ix_n_d < 0 || ix_n_d >= out_shape.dim_size(d)) { + invalid_dims = true; + } + ix += strides[d] * ix_n_d; + } + if (invalid_dims) return false; + out_t(ix) = vals_t(n); + } + return true; +} + +template +SparseTensor SparseTensor::Concat( + const gtl::ArraySlice& tensors) { + CHECK_GE(tensors.size(), 1) << "Cannot concat 0 SparseTensors"; + const int dims = tensors[0].dims_; + CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors"; + auto order_0 = tensors[0].order(); + const int primary_dim = order_0[0]; + gtl::InlinedVector final_order(order_0.begin(), order_0.end()); + TensorShape final_shape(tensors[0].shape()); + final_shape.set_dim(primary_dim, 0); // We'll build this up as we go along. + int num_entries = 0; + + bool fully_ordered = true; + for (const SparseTensor& st : tensors) { + CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank."; + CHECK_EQ(DataTypeToEnum::v(), st.dtype()) + << "Concat requested with the wrong data type"; + CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered"; + CHECK_EQ(st.order()[0], primary_dim) + << "All SparseTensors' order[0] must match. This is the concat dim."; + if (st.order() != final_order) fully_ordered = false; + const TensorShape st_shape = st.shape(); + for (int d = 0; d < dims - 1; ++d) { + const int cdim = (d < primary_dim) ? d : d + 1; + CHECK_EQ(final_shape.dim_size(cdim), st_shape.dim_size(cdim)) + << "All SparseTensors' shapes must match except on the concat dim. " + << "Concat dim: " << primary_dim + << ", mismatched shape at dim: " << cdim + << ". Expecting shape like: " << final_shape.DebugString() + << " but saw shape: " << st_shape.DebugString(); + } + + // Update dimension of final shape + final_shape.set_dim(primary_dim, final_shape.dim_size(primary_dim) + + st_shape.dim_size(primary_dim)); + + num_entries += st.num_entries(); // Update number of entries + } + + // If nonconsistent ordering among inputs, set final order to -1s. + if (!fully_ordered) { + final_order = UndefinedOrder(final_shape); + } + + Tensor output_ix(DT_INT64, TensorShape({num_entries, dims})); + Tensor output_vals(DataTypeToEnum::v(), TensorShape({num_entries})); + + auto ix_t = output_ix.matrix(); + auto vals_t = output_vals.vec(); + + Eigen::DenseIndex offset = 0; + int64 shape_offset = 0; + for (const SparseTensor& st : tensors) { + int st_num_entries = st.num_entries(); + Eigen::DSizes ix_start(offset, 0); + Eigen::DSizes ix_size(st_num_entries, dims); + Eigen::DSizes vals_start(offset); + Eigen::DSizes vals_size(st_num_entries); + + // Fill in indices & values. 
+ ix_t.slice(ix_start, ix_size) = st.ix_.matrix(); + vals_t.slice(vals_start, vals_size) = st.vals_.vec(); + + Eigen::DSizes ix_update_start(offset, primary_dim); + Eigen::DSizes ix_update_size(st_num_entries, 1); + // The index associated with the primary dimension gets increased + // by the shapes of the previous concatted Tensors. + auto update_slice = ix_t.slice(ix_update_start, ix_update_size); + update_slice += update_slice.constant(shape_offset); + + offset += st_num_entries; + shape_offset += st.shape().dim_size(primary_dim); + } + + return SparseTensor(output_ix, output_vals, final_shape, final_order); +} + +} // namespace sparse +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc new file mode 100644 index 0000000000..47126b7187 --- /dev/null +++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc @@ -0,0 +1,467 @@ +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +#include +#include + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/public/tensor.h" +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { +namespace { + +Eigen::Tensor +GetSimpleIndexTensor(int N, const int NDIM) { + Eigen::Tensor ix(N, NDIM); + ix(0, 0) = 0; + ix(0, 1) = 0; + ix(0, 2) = 0; + + ix(1, 0) = 3; + ix(1, 1) = 0; + ix(1, 2) = 0; + + ix(2, 0) = 2; + ix(2, 1) = 0; + ix(2, 2) = 0; + + ix(3, 0) = 0; + ix(3, 1) = 1; + ix(3, 2) = 0; + + ix(4, 0) = 0; + ix(4, 1) = 0; + ix(4, 2) = 2; + return ix; +} + +TEST(SparseTensorTest, DimComparatorSorts) { + std::size_t N = 5; + const int NDIM = 3; + auto ix = GetSimpleIndexTensor(N, NDIM); + TTypes::Matrix map(ix.data(), N, NDIM); + + std::vector sorting(N); + for (std::size_t n = 0; n < N; ++n) sorting[n] = n; + + // new order should be: {0, 4, 3, 2, 1} + std::vector order{0, 1, 2}; + DimComparator sorter(map, order, NDIM); + std::sort(sorting.begin(), sorting.end(), sorter); + + EXPECT_EQ(sorting, std::vector({0, 4, 3, 2, 1})); + + // new order should be: {0, 3, 2, 1, 4} + std::vector order1{2, 0, 1}; + DimComparator sorter1(map, order1, NDIM); + for (std::size_t n = 0; n < N; ++n) sorting[n] = n; + std::sort(sorting.begin(), sorting.end(), sorter1); + + EXPECT_EQ(sorting, std::vector({0, 3, 2, 1, 4})); +} + +TEST(SparseTensorTest, SparseTensorConstruction) { + int N = 5; + const int NDIM = 3; + auto ix_c = GetSimpleIndexTensor(N, NDIM); + Eigen::Tensor vals_c(N); + vals_c(0) = "hi0"; + vals_c(1) = "hi1"; + vals_c(2) = "hi2"; + vals_c(3) = "hi3"; + vals_c(4) = "hi4"; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_t = ix.matrix(); + auto vals_t = vals.vec(); + vals_t = vals_c; + ix_t = ix_c; + + TensorShape shape({10, 10, 10}); + std::vector order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + EXPECT_FALSE(st.IndicesValid()); // Out of order + + // Regardless of how order is updated; so long as there are no + // duplicates, the resulting indices are valid. 
+ st.Reorder({2, 0, 1}); + EXPECT_TRUE(st.IndicesValid()); + EXPECT_EQ(vals_t(0), "hi0"); + EXPECT_EQ(vals_t(1), "hi3"); + EXPECT_EQ(vals_t(2), "hi2"); + EXPECT_EQ(vals_t(3), "hi1"); + EXPECT_EQ(vals_t(4), "hi4"); + + ix_t = ix_c; + vals_t = vals_c; + st.Reorder({0, 1, 2}); + EXPECT_TRUE(st.IndicesValid()); + EXPECT_EQ(vals_t(0), "hi0"); + EXPECT_EQ(vals_t(1), "hi4"); + EXPECT_EQ(vals_t(2), "hi3"); + EXPECT_EQ(vals_t(3), "hi2"); + EXPECT_EQ(vals_t(4), "hi1"); + + ix_t = ix_c; + vals_t = vals_c; + st.Reorder({2, 1, 0}); + EXPECT_TRUE(st.IndicesValid()); +} + +TEST(SparseTensorTest, EmptySparseTensorAllowed) { + int N = 0; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + TensorShape shape({10, 10, 10}); + std::vector order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + EXPECT_TRUE(st.IndicesValid()); + EXPECT_EQ(st.order(), order); + + std::vector new_order{1, 0, 2}; + st.Reorder(new_order); + EXPECT_TRUE(st.IndicesValid()); + EXPECT_EQ(st.order(), new_order); +} + +TEST(SparseTensorTest, SortingWorksCorrectly) { + int N = 30; + const int NDIM = 4; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + TensorShape shape({1000, 1000, 1000, 1000}); + SparseTensor st(ix, vals, shape); + + auto ix_t = ix.matrix(); + + for (int n = 0; n < 100; ++n) { + ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator(n + 1)); + ix_t = ix_t.abs() % 1000; + st.Reorder({0, 1, 2, 3}); + EXPECT_TRUE(st.IndicesValid()); + st.Reorder({3, 2, 1, 0}); + EXPECT_TRUE(st.IndicesValid()); + st.Reorder({1, 0, 2, 3}); + EXPECT_TRUE(st.IndicesValid()); + st.Reorder({3, 0, 2, 1}); + EXPECT_TRUE(st.IndicesValid()); + } +} + +TEST(SparseTensorTest, ValidateIndicesFindsInvalid) { + int N = 2; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + Eigen::Tensor ix_orig(N, NDIM); + ix_orig(0, 0) = 0; + ix_orig(0, 1) = 0; + ix_orig(0, 2) = 0; + + ix_orig(1, 0) = 0; + ix_orig(1, 1) = 0; + ix_orig(1, 2) = 0; + + auto ix_t = ix.matrix(); + ix_t = ix_orig; + + TensorShape shape({10, 10, 10}); + std::vector order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + + st.Reorder(order); + EXPECT_FALSE(st.IndicesValid()); // two indices are identical + + ix_orig(1, 2) = 1; + ix_t = ix_orig; + st.Reorder(order); + EXPECT_TRUE(st.IndicesValid()); // second index now (0, 0, 1) + + ix_orig(0, 2) = 1; + ix_t = ix_orig; + st.Reorder(order); + EXPECT_FALSE(st.IndicesValid()); // first index now (0, 0, 1) +} + +TEST(SparseTensorTest, SparseTensorCheckBoundaries) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_t = GetSimpleIndexTensor(N, NDIM); + + ix.matrix() = ix_t; + + TensorShape shape({10, 10, 10}); + std::vector order{0, 1, 2}; + + SparseTensor st(ix, vals, shape, order); + EXPECT_FALSE(st.IndicesValid()); + + st.Reorder(order); + EXPECT_TRUE(st.IndicesValid()); + + ix_t(0, 0) = 11; + ix.matrix() = ix_t; + st.Reorder(order); + EXPECT_FALSE(st.IndicesValid()); + + ix_t(0, 0) = -1; + ix.matrix() = ix_t; + st.Reorder(order); + EXPECT_FALSE(st.IndicesValid()); + + ix_t(0, 0) = 0; + ix.matrix() = ix_t; + st.Reorder(order); + EXPECT_TRUE(st.IndicesValid()); +} + +TEST(SparseTensorTest, SparseTensorToDenseTensor) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_t = 
GetSimpleIndexTensor(N, NDIM); + auto vals_t = vals.vec(); + + ix.matrix() = ix_t; + + vals_t(0) = "hi0"; + vals_t(1) = "hi1"; + vals_t(2) = "hi2"; + vals_t(3) = "hi3"; + vals_t(4) = "hi4"; + + TensorShape shape({4, 4, 5}); + std::vector order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + + Tensor dense(DT_STRING, TensorShape({4, 4, 5})); + st.ToDense(&dense); + + auto dense_t = dense.tensor(); + Eigen::array ix_n; + for (int n = 0; n < N; ++n) { + for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d); + EXPECT_EQ(dense_t(ix_n), vals_t(n)); + } + + // Spot checks on the others + EXPECT_EQ(dense_t(0, 0, 1), ""); + EXPECT_EQ(dense_t(0, 0, 3), ""); + EXPECT_EQ(dense_t(3, 3, 3), ""); + EXPECT_EQ(dense_t(3, 3, 4), ""); +} + +TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_t = GetSimpleIndexTensor(N, NDIM); + auto vals_t = vals.vec(); + + ix.matrix() = ix_t; + + vals_t(0) = "hi0"; + vals_t(1) = "hi1"; + vals_t(2) = "hi2"; + vals_t(3) = "hi3"; + vals_t(4) = "hi4"; + + TensorShape shape({4, 4, 5}); + std::vector order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + + Tensor dense(DT_STRING, TensorShape({10, 10, 10})); + st.ToDense(&dense); + + auto dense_t = dense.tensor(); + Eigen::array ix_n; + for (int n = 0; n < N; ++n) { + for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d); + EXPECT_EQ(dense_t(ix_n), vals_t(n)); + } + + // Spot checks on the others + EXPECT_EQ(dense_t(0, 0, 1), ""); + EXPECT_EQ(dense_t(0, 0, 3), ""); + EXPECT_EQ(dense_t(3, 3, 3), ""); + EXPECT_EQ(dense_t(3, 3, 4), ""); + EXPECT_EQ(dense_t(9, 0, 0), ""); + EXPECT_EQ(dense_t(9, 0, 9), ""); + EXPECT_EQ(dense_t(9, 9, 9), ""); +} + +TEST(SparseTensorTest, SparseTensorGroup) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_INT32, TensorShape({N})); + + auto ix_t = ix.matrix(); + auto vals_t = vals.vec(); + + ix_t = GetSimpleIndexTensor(N, NDIM); + + vals_t(0) = 1; // associated with ix (000) + vals_t(1) = 2; // associated with ix (300) + vals_t(2) = 3; // associated with ix (200) + vals_t(3) = 4; // associated with ix (010) + vals_t(4) = 5; // associated with ix (002) + + TensorShape shape({10, 10, 10}); + std::vector order{0, 1, 2}; + + SparseTensor st(ix, vals, shape, order); + st.Reorder(order); + + std::vector > groups; + std::vector::UnalignedConstMatrix> grouped_indices; + std::vector::UnalignedVec> grouped_values; + + // Group by index 0 + auto gi = st.group({0}); + + // All the hard work is right here! 
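+ // st.group({0}) iterates over one group per distinct value of dimension 0
+ // among the (sorted) indices; each group exposes its key via group() and
+ // views of its member rows via indices() and values(), which the loop below
+ // collects for comparison against the expected slices.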
+ for (const auto& g : gi) { + groups.push_back(g.group()); + VLOG(1) << "Group: " << str_util::Join(g.group(), ","); + VLOG(1) << "Indices: " << g.indices(); + VLOG(1) << "Values: " << g.values(); + + grouped_indices.push_back(g.indices()); + grouped_values.push_back(g.values()); + } + + // Group by dimension 0, we have groups: 0--, 2--, 3-- + EXPECT_EQ(groups.size(), 3); + EXPECT_EQ(groups[0], std::vector({0})); + EXPECT_EQ(groups[1], std::vector({2})); + EXPECT_EQ(groups[2], std::vector({3})); + + std::vector > expected_indices; + std::vector > expected_vals; + + // First group: 000, 002, 010 + expected_indices.emplace_back(3, NDIM); // 3 x 3 tensor + expected_vals.emplace_back(3); // 3 x 5 x 1 x 1 tensor + expected_indices[0].setZero(); + expected_indices[0](1, 2) = 2; // 002 + expected_indices[0](2, 1) = 1; // 010 + expected_vals[0].setConstant(-1); + expected_vals[0](0) = 1; // val associated with ix 000 + expected_vals[0](1) = 5; // val associated with ix 002 + expected_vals[0](2) = 4; // val associated with ix 010 + + // Second group: 200 + expected_indices.emplace_back(1, NDIM); + expected_vals.emplace_back(1); + expected_indices[1].setZero(); + expected_indices[1](0, 0) = 2; // 200 + expected_vals[1](0) = 3; // val associated with ix 200 + + // Third group: 300 + expected_indices.emplace_back(1, NDIM); + expected_vals.emplace_back(1); + expected_indices[2].setZero(); + expected_indices[2](0, 0) = 3; // 300 + expected_vals[2](0) = 2; // val associated with ix 300 + + for (std::size_t gix = 0; gix < groups.size(); ++gix) { + // Compare indices + auto gi_t = grouped_indices[gix]; + Eigen::Tensor eval = + (gi_t == expected_indices[gix]).all(); + EXPECT_TRUE(eval()) << gix << " indices: " << gi_t << " vs. " + << expected_indices[gix]; + + // Compare values + auto gv_t = grouped_values[gix]; + eval = (gv_t == expected_vals[gix]).all(); + EXPECT_TRUE(eval()) << gix << " values: " << gv_t << " vs. 
" + << expected_vals[gix]; + } +} + +TEST(SparseTensorTest, Concat) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_c = GetSimpleIndexTensor(N, NDIM); + + auto ix_t = ix.matrix(); + auto vals_t = vals.vec(); + + ix_t = ix_c; + + TensorShape shape({10, 10, 10}); + std::vector order{0, 1, 2}; + + SparseTensor st(ix, vals, shape, order); + EXPECT_FALSE(st.IndicesValid()); + st.Reorder(order); + EXPECT_TRUE(st.IndicesValid()); + + SparseTensor concatted = SparseTensor::Concat({st, st, st, st}); + EXPECT_EQ(concatted.order(), st.order()); + TensorShape expected_shape({40, 10, 10}); + EXPECT_EQ(concatted.shape(), expected_shape); + EXPECT_EQ(concatted.num_entries(), 4 * N); + EXPECT_TRUE(concatted.IndicesValid()); + + auto conc_ix_t = concatted.indices().matrix(); + auto conc_vals_t = concatted.values().vec(); + + for (int n = 0; n < 4; ++n) { + for (int i = 0; i < N; ++i) { + // Dimensions match except the primary dim, which is offset by + // shape[order[0]] + EXPECT_EQ(conc_ix_t(n * N + i, 0), 10 * n + ix_t(i, 0)); + EXPECT_EQ(conc_ix_t(n * N + i, 1), ix_t(i, 1)); + EXPECT_EQ(conc_ix_t(n * N + i, 1), ix_t(i, 1)); + + // Values match + EXPECT_EQ(conc_vals_t(n * N + i), vals_t(i)); + } + } + + // Concat works if non-primary ix is out of order, but output order + // is not defined + SparseTensor st_ooo(ix, vals, shape, {0, 2, 1}); // non-primary ix OOO + SparseTensor conc_ooo = SparseTensor::Concat({st, st, st, st_ooo}); + std::vector expected_ooo{-1, -1, -1}; + EXPECT_EQ(conc_ooo.order(), expected_ooo); + EXPECT_EQ(conc_ooo.shape(), expected_shape); + EXPECT_EQ(conc_ooo.num_entries(), 4 * N); +} + +// TODO(ebrevdo): ReduceToDense(R={dim1,dim2,...}, reduce_fn, &output) +// reduce_fn sees slices of resorted values based on generator (dim: DDIMS), and +// slices of resorted indices on generator. 
+ +} // namespace +} // namespace sparse +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc new file mode 100644 index 0000000000..00bc16f105 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader.cc @@ -0,0 +1,230 @@ +#include "tensorflow/core/util/tensor_slice_reader.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/io/iterator.h" +#include "tensorflow/core/lib/io/match.h" +#include "tensorflow/core/lib/io/table.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_util.h" + +namespace tensorflow { + +namespace checkpoint { + +TensorSliceReader::Table::~Table() {} + +namespace { +class TensorSliceReaderTable : public TensorSliceReader::Table { + public: + explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t) + : file_(f), table_(t) {} + + ~TensorSliceReaderTable() override { + delete table_; + delete file_; + } + + bool Get(const string& key, string* value) override { + std::unique_ptr iter(table_->NewIterator()); + iter->Seek(key); + if (iter->Valid() && iter->key() == key) { + StringPiece v = iter->value(); + value->assign(v.data(), v.size()); + return true; + } else { + return false; + } + } + + private: + RandomAccessFile* file_; + table::Table* table_; +}; +} // namespace + +Status OpenTableTensorSliceReader(const string& fname, + TensorSliceReader::Table** result) { + *result = nullptr; + Env* env = Env::Default(); + RandomAccessFile* f = nullptr; + Status s = env->NewRandomAccessFile(fname, &f); + if (s.ok()) { + uint64 file_size; + s = env->GetFileSize(fname, &file_size); + if (s.ok()) { + table::Options options; + table::Table* table; + s = table::Table::Open(options, f, file_size, &table); + if (s.ok()) { + *result = new TensorSliceReaderTable(f, table); + return Status::OK(); + } else { + s = Status(s.code(), + strings::StrCat(s.error_message(), + ": perhaps your file is in a different " + "file format and you need to use a " + "different restore operator?")); + } + } + } + LOG(WARNING) << "Could not open " << fname << ": " << s; + delete f; + return s; +} + +TensorSliceReader::TensorSliceReader(const string& filepattern, + OpenTableFunction open_function) + : TensorSliceReader(filepattern, open_function, kLoadAllShards) {} + +TensorSliceReader::TensorSliceReader(const string& filepattern, + OpenTableFunction open_function, + int preferred_shard) + : filepattern_(filepattern), open_function_(open_function) { + VLOG(1) << "TensorSliceReader for " << filepattern; + Status s = io::GetMatchingFiles(Env::Default(), filepattern, &fnames_); + if (!s.ok()) { + status_ = errors::InvalidArgument( + "Unsuccessful TensorSliceReader constructor: " + "Failed to get matching files on ", + filepattern, ": ", s.ToString()); + return; + } + if (fnames_.empty()) { + status_ = errors::NotFound( + "Unsuccessful TensorSliceReader constructor: " + "Failed to find any matching files for ", + filepattern); + return; + } + sss_.resize(fnames_.size()); + for (size_t shard = 0; shard < fnames_.size(); ++shard) { + fname_to_index_.insert(std::make_pair(fnames_[shard], shard)); + } + if (preferred_shard == kLoadAllShards || fnames_.size() == 1 || + static_cast(preferred_shard) >= 
fnames_.size()) { + LoadAllShards(); + } else { + VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_; + LoadShard(preferred_shard); + } +} + +void TensorSliceReader::LoadShard(int shard) const { + CHECK_LT(shard, sss_.size()); + if (sss_[shard] || !status_.ok()) { + return; // Already loaded, or invalid. + } + string value; + SavedTensorSlices sts; + const string fname = fnames_[shard]; + VLOG(1) << "Reading meta data from file " << fname << "..."; + Table* table; + Status s = open_function_(fname, &table); + if (!s.ok()) { + status_ = errors::DataLoss("Unable to open table file ", fname, ": ", + s.ToString()); + return; + } + sss_[shard].reset(table); + if (!(table->Get(kSavedTensorSlicesKey, &value) && + ParseProtoUnlimited(&sts, value))) { + status_ = errors::Internal( + "Failed to find the saved tensor slices at the beginning of the " + "checkpoint file: ", + fname); + return; + } + for (const SavedSliceMeta& ssm : sts.meta().tensor()) { + TensorShape ssm_shape(ssm.shape()); + for (const TensorSliceProto& tsp : ssm.slice()) { + TensorSlice ss_slice(tsp); + RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname, ss_slice); + } + } +} + +void TensorSliceReader::LoadAllShards() const { + VLOG(1) << "Loading all shards for " << filepattern_; + for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) { + LoadShard(i); + } + all_shards_loaded_ = true; +} + +const TensorSliceSet* TensorSliceReader::FindTensorSlice( + const string& name, const TensorSlice& slice, + std::vector>* details) const { + const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name); + if (tss && !tss->QueryMeta(slice, details)) { + return nullptr; + } + return tss; +} + +TensorSliceReader::~TensorSliceReader() { gtl::STLDeleteValues(&tensors_); } + +void TensorSliceReader::RegisterTensorSlice(const string& name, + const TensorShape& shape, + DataType type, const string& tag, + const TensorSlice& slice) const { + TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name); + // Create a tensor slice set if needed + if (!tss) { + tss = new TensorSliceSet(shape, type); + tensors_.insert(std::make_pair(name, tss)); + } else { + // Check if the shapes match + TensorShape tss_shape(tss->shape()); + if (!shape.IsSameSize(tss_shape)) { + status_ = + errors::Internal("Incompatible tensor shapes detected for tensor ", + name, ": existing = ", tss_shape.DebugString(), + ", new = ", shape.DebugString()); + return; + } + if (type != tss->type()) { + status_ = + errors::Internal("Incompatible tensor types detected for tensor ", + name, ": existing = ", DataTypeString(tss->type()), + ", new = ", DataTypeString(type)); + return; + } + } + // Register the tensor slices without the actual data. 
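+ // Passing nullptr for the data pointer records only the slice geometry and
+ // the originating file (the "tag"); the actual values are read lazily in
+ // CopySliceData(), which maps the tag back to a shard via fname_to_index_.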
+ Status s = tss->Register(slice, tag, nullptr); + if (!s.ok()) { + status_ = s; + } +} + +bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape, + DataType* type) const { + mutex_lock l(mu_); + const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name); + if (!tss && !all_shards_loaded_) { + VLOG(1) << "Did not find tensor in preferred shard, loading all shards: " + << name; + LoadAllShards(); + tss = gtl::FindPtrOrNull(tensors_, name); + } + if (tss) { + if (shape) { + *shape = tss->shape(); + } + if (type) { + *type = tss->type(); + } + return true; + } else { + return false; + } +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h new file mode 100644 index 0000000000..b5f26a689b --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader.h @@ -0,0 +1,157 @@ +// The utility to read checkpoints for google brain tensor ops and v3 +// checkpoints for dist_belief. +// + +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_ + +#include + +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/util/saved_tensor_slice.pb.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_set.h" +#include "tensorflow/core/util/tensor_slice_util.h" + +namespace tensorflow { + +namespace checkpoint { + +// The reader reads in all the meta data about all the tensor slices. Then it +// will try to read the relevant data on-demand to produce the data for the +// slices needed. +// NOTE(yangke): another way to do this is to first load a list of the tensor +// slices needed and then just selectively read some of the meta data. That +// might optimize the loading but makes the logic a bit more complicated. We +// might want to revisit that. +// TODO(yangke): consider moving to TensorProto. +class TensorSliceReader { + public: + // Abstract interface for reading data out of a tensor slice checkpoint file + class Table { + public: + virtual ~Table(); + virtual bool Get(const string& key, string* value) = 0; + }; + typedef std::function OpenTableFunction; + + static const int kLoadAllShards = -1; + TensorSliceReader(const string& filepattern, OpenTableFunction open_function); + TensorSliceReader(const string& filepattern, OpenTableFunction open_function, + int preferred_shard); + virtual ~TensorSliceReader(); + + // Get the filename this reader is attached to. + const string& filepattern() const { return filepattern_; } + + // Get the number of files matched. + int num_files() const { return sss_.size(); } + + // Get the status of the reader. + const Status status() const { return status_; } + + // Checks if the reader contains any slice of a tensor. In case the reader + // does contain the tensor, if "shape" is not nullptr, fill "shape" with the + // shape of the tensor; if "type" is not nullptr, fill "type" with the type + // of the tensor. 
+ bool HasTensor(const string& name, TensorShape* shape, DataType* type) const; + + // Checks if the reader contains all the data about a tensor slice, and if + // yes, copies the data of the slice to "data". The caller needs to make sure + // that "data" points to a buffer that holds enough data. + // This is a slow function since it needs to read sstables. + template + bool CopySliceData(const string& name, const TensorSlice& slice, + T* data) const; + + // Get the tensors. + const std::unordered_map& Tensors() const { + return tensors_; + } + + private: + friend class TensorSliceWriteTestHelper; + + void LoadShard(int shard) const; + void LoadAllShards() const; + void RegisterTensorSlice(const string& name, const TensorShape& shape, + DataType type, const string& tag, + const TensorSlice& slice) const; + + const TensorSliceSet* FindTensorSlice( + const string& name, const TensorSlice& slice, + std::vector>* details) const; + + const string filepattern_; + const OpenTableFunction open_function_; + std::vector fnames_; + std::unordered_map fname_to_index_; + + // Guards the attributes below. + mutable mutex mu_; + mutable bool all_shards_loaded_ = false; + mutable std::vector> sss_; + mutable std::unordered_map tensors_; + mutable Status status_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceReader); +}; + +Status OpenTableTensorSliceReader(const string& fname, + TensorSliceReader::Table** table); + +template +bool TensorSliceReader::CopySliceData(const string& name, + const TensorSlice& slice, T* data) const { + std::vector> details; + const TensorSliceSet* tss; + { + mutex_lock l(mu_); + tss = FindTensorSlice(name, slice, &details); + if (!tss && !all_shards_loaded_) { + VLOG(1) << "Did not find slice in preferred shard, loading all shards." + << name << ": " << slice.DebugString(); + LoadAllShards(); + tss = FindTensorSlice(name, slice, &details); + } + if (!tss) { + // No such tensor + return false; + } + } + // We have the data -- copy it over. 
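+ // "details" holds (slice, filename) pairs from QueryMeta(). For each pair we
+ // look up the shard's table, fetch the record keyed by
+ // EncodeTensorNameSlice(name, slice), parse it as SavedTensorSlices, and
+ // copy the overlapping region into the caller's buffer.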
+ string value; + for (const auto& x : details) { + const TensorSlice& slice_s = x.first; + const string& fname = x.second; + int idx = gtl::FindWithDefault(fname_to_index_, fname, -1); + CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname; + // We read a record in the corresponding sstable + const string key = EncodeTensorNameSlice(name, slice_s); + CHECK(sss_[idx]->Get(key, &value)) + << "Failed to seek to the record for tensor " << name << ", slice " + << slice_s.DebugString() << ": computed key = " << key; + SavedTensorSlices sts; + CHECK(ParseProtoUnlimited(&sts, value)) + << "Failed to parse the record for tensor " << name << ", slice " + << slice_s.DebugString() << ": computed key = " << key; + CopyDataFromTensorSliceToTensorSlice( + tss->shape(), slice_s, slice, + checkpoint::TensorProtoData(sts.data().data()), data); + } + return true; +} + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_ diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc new file mode 100644 index 0000000000..af81d0115e --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader_cache.cc @@ -0,0 +1,94 @@ +#include "tensorflow/core/util/tensor_slice_reader_cache.h" + +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace checkpoint { + +TensorSliceReaderCacheWrapper::TensorSliceReaderCacheWrapper() {} +TensorSliceReaderCacheWrapper::~TensorSliceReaderCacheWrapper() { + if (cache_) { + delete cache_; + } + cache_ = nullptr; +} + +const TensorSliceReader* TensorSliceReaderCacheWrapper::GetReader( + const string& filepattern, + TensorSliceReader::OpenTableFunction open_function, + int preferred_shard) const { + mutex_lock l(mu_); + if (!cache_) { + cache_ = new TensorSliceReaderCache; + } + return cache_->GetReader(filepattern, open_function, preferred_shard); +} + +TensorSliceReaderCache::TensorSliceReaderCache() {} + +TensorSliceReaderCache::~TensorSliceReaderCache() { + for (auto pair : readers_) { + delete pair.second.second; + } +} + +const TensorSliceReader* TensorSliceReaderCache::GetReader( + const string& filepattern, + TensorSliceReader::OpenTableFunction open_function, int preferred_shard) { + mutex_lock l(mu_); + + // Get the function pointer from the open_function value. + TensorSliceReaderCache::OpenFuncType* func_ptr = + open_function.target(); + if (!func_ptr) { + // We could not get the pointer, no caching is possible. + LOG(WARNING) << "Caching disabled because the open function is a lambda."; + return nullptr; + } + + // Wait if another thread is already trying to open the same files. + while (still_opening_.find(filepattern) != still_opening_.end()) { + cv_.wait(l); + } + + TensorSliceReader* reader = nullptr; + if (readers_.find(filepattern) == readers_.end()) { + VLOG(1) << "Creating new TensorSliceReader for " << filepattern; + still_opening_.insert(filepattern); + // Release the lock temporary as constructing TensorSliceReader is + // expensive. + mu_.unlock(); + TensorSliceReader* tmp_reader( + new TensorSliceReader(filepattern, open_function, preferred_shard)); + // Acquire the lock again. 
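+ // While the lock is dropped, other threads asking for the same filepattern
+ // block in the still_opening_ wait loop above, so only one reader is
+ // constructed per filepattern; waiters are released by cv_.notify_all()
+ // at the end of this function.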
+ mu_.lock(); + if (tmp_reader->status().ok()) { + reader = tmp_reader; + readers_[filepattern] = make_pair(*func_ptr, reader); + } else { + delete tmp_reader; + } + CHECK_EQ(1, still_opening_.erase(filepattern)); + VLOG(1) << "Cached TensorSliceReader for " << filepattern << ": " << reader; + } else { + auto cached_val = readers_[filepattern]; + if (cached_val.first == *func_ptr) { + reader = cached_val.second; + VLOG(1) << "Using cached TensorSliceReader for " << filepattern << ": " + << reader; + } else { + LOG(WARNING) << "Caching disabled because the checkpoint file " + << "is being opened with two different open functions: " + << filepattern; + } + } + + cv_.notify_all(); + return reader; +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_reader_cache.h b/tensorflow/core/util/tensor_slice_reader_cache.h new file mode 100644 index 0000000000..eaeeeec83f --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader_cache.h @@ -0,0 +1,73 @@ +// The utility to read checkpoints for google brain tensor ops and v3 +// checkpoints for dist_belief. +// + +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_ + +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/util/tensor_slice_reader.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +namespace checkpoint { + +class TensorSliceReaderCache; + +// Wrapper to a lazily allocated TensorSliceReaderCache. +class TensorSliceReaderCacheWrapper { + public: + TensorSliceReaderCacheWrapper(); + ~TensorSliceReaderCacheWrapper(); + + // Same as TensorSliceReaderCache::GetReader(). + const TensorSliceReader* GetReader( + const string& filepattern, + TensorSliceReader::OpenTableFunction open_function, + int preferred_shard) const; + + private: + mutable mutex mu_; + mutable TensorSliceReaderCache* cache_ = nullptr; +}; + +// A cache of TensorSliceReaders. +class TensorSliceReaderCache { + public: + TensorSliceReaderCache(); + ~TensorSliceReaderCache(); + + // Returns the TensorSliceReader corresponding to 'filepattern' and the + // open_function. May return nullptr if we can not create a new + // TensorSliceReader for the filepattern/open_function combination. + const TensorSliceReader* GetReader( + const string& filepattern, + TensorSliceReader::OpenTableFunction open_function, int preferred_shard); + + private: + // Need to use a regular function type in the key map as std::function does + // not support ==. + typedef Status (*OpenFuncType)(const string&, TensorSliceReader::Table**); + + // Protects attributes below. + mutex mu_; + + // Maps of opened readers. + std::unordered_map> + readers_; + + // Set of keys that a previous GetReader() call is still trying to populate. + std::set still_opening_; + + // Condition variable to notify when a reader has been created. 
+ condition_variable cv_; +}; + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_ diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc new file mode 100644 index 0000000000..e14b920003 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader_test.cc @@ -0,0 +1,395 @@ +#include "tensorflow/core/util/tensor_slice_reader.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_writer.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +// A simple test where we write a few tensor slices with a number of tensor +// slice writers and then read them back from a tensor slice reader. +// +// We have a 2-d tensor of shape 4 X 5 that looks like this: +// +// 0 1 2 3 4 +// 5 6 7 8 9 +// 10 11 12 13 14 +// 15 16 17 18 19 +// +// We assume this is a row-major matrix. + +void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function) { + const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const float data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const float data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . 
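+ // Because elements (2,3) and (2,4) were never written, any query slice that
+ // touches them cannot be fully covered, which the final CopySliceData check
+ // below relies on.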
+ + // Now we need to read the tensor slices + const string filepattern = strings::StrCat(fname_base, "_*"); + TensorSliceReader reader(filepattern, open_function); + EXPECT_OK(reader.status()); + EXPECT_EQ(2, reader.num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DT_FLOAT, type); + EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr)); + } + + // Now we query some slices + // + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + float results[10]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + float expected[] = {5, 6, 7, 8, 9}; + float results[5]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + float results[6]; + EXPECT_FALSE(reader.CopySliceData("test", s, results)); + } +} + +TEST(TensorSliceReaderTest, SimpleFloat) { + SimpleFloatHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader); +} + +template +void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function, + const string& checkpoint_file) { + const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const T data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const T data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const T data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . 
+ + // Now we need to read the tensor slices + const string filepattern = strings::StrCat(fname_base, "_*"); + TensorSliceReader reader(filepattern, open_function); + EXPECT_OK(reader.status()); + EXPECT_EQ(2, reader.num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DataTypeToEnum::v(), type); + EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr)); + } + + // Now we query some slices + // + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + T expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + U results[10]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + T expected[] = {5, 6, 7, 8, 9}; + U results[5]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + U results[6]; + EXPECT_FALSE(reader.CopySliceData("test", s, results)); + } +} + +#define TEST_SIMPLE_INT(TYPE, SAVED_TYPE) \ + TEST(TensorSliceReaderTest, Simple##TYPE) { \ + SimpleIntXHelper(CreateTableTensorSliceBuilder, \ + OpenTableTensorSliceReader, \ + #TYPE "_checkpoint"); \ + } + +TEST_SIMPLE_INT(int32, int32) +TEST_SIMPLE_INT(int64, int64) +TEST_SIMPLE_INT(int16, int32) +TEST_SIMPLE_INT(int8, int32) +TEST_SIMPLE_INT(uint8, int32) + +void CachedTensorSliceReaderTesterHelper( + TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function) { + const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const float data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const float data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . 
+ + // Now we need to read the tensor slices + TensorSliceReaderCache cache; + const string filepattern = strings::StrCat(fname_base, "_*"); + const TensorSliceReader* reader = cache.GetReader( + filepattern, open_function, TensorSliceReader::kLoadAllShards); + EXPECT_TRUE(reader != nullptr); + EXPECT_EQ(2, reader->num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader->HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DT_FLOAT, type); + EXPECT_FALSE(reader->HasTensor("don't exist", nullptr, nullptr)); + } + + // Make sure the reader is cached. + const TensorSliceReader* reader2 = cache.GetReader( + filepattern, open_function, TensorSliceReader::kLoadAllShards); + EXPECT_EQ(reader, reader2); + + reader = cache.GetReader("file_does_not_exist", open_function, + TensorSliceReader::kLoadAllShards); + EXPECT_TRUE(reader == nullptr); +} + +TEST(CachedTensorSliceReaderTest, SimpleFloat) { + CachedTensorSliceReaderTesterHelper(CreateTableTensorSliceBuilder, + OpenTableTensorSliceReader); +} + +} // namespace + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc new file mode 100644 index 0000000000..765686f189 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_set.cc @@ -0,0 +1,148 @@ +#include "tensorflow/core/util/tensor_slice_set.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/tensor_slice_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +namespace checkpoint { + +TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type) + : shape_(shape), type_(type) {} + +TensorSliceSet::~TensorSliceSet() {} + +Status TensorSliceSet::Register(const TensorSlice& slice, + const string& tag, const float* data) { + TensorShape result_shape; + TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape)); + string str = slice.DebugString(); + // We check if there is any intersection between this slice and any of the + // registered slices. + for (const auto x : slices_) { + if (slice.Overlaps(x.second.slice)) { + return errors::Internal("Overlapping slices: existing slice = ", x.first, + ", new slice = ", str); + } + } + // No overlap: we can now insert the slice + TensorSliceSet::SliceInfo info = {slice, tag, data, + result_shape.num_elements()}; + slices_.insert(std::make_pair(str, info)); + return Status::OK(); +} + +// TODO(yangke): merge Query() with QueryMeta() +bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const { + Status s; + string str = slice.DebugString(); + // First we check if there is an exactly match (this is the dominant case). + const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str); + if (info) { + if (data) { + std::copy_n(info->data, info->num_floats, data); + } + return true; + } else { + // We didn't find any exact match but there is still a posibility that + // mutliple existing slices can be patched together to output the slice. + // We figure this out by computing the intersection of each of the existing + // slices with the query slice, and check if the union of all these + // intersections cover the entire slice. We rely on the fact that the + // existing slices don't have any intersection among themselves. 
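+ // Because Register() rejects overlapping slices, the per-slice intersections
+ // computed below are pairwise disjoint, so comparing the sum of their
+ // element counts against the query slice's element count is a valid test
+ // for full coverage.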
+ TensorShape target_shape; + Status s; + s = slice.SliceTensorShape(shape_, &target_shape); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + int64 total_size = target_shape.num_elements(); + + int64 overlap_size = 0; + TensorSlice intersection; + TensorShape inter_shape; + for (const auto x : slices_) { + if (slice.Intersect(x.second.slice, &intersection)) { + s = intersection.SliceTensorShape(shape_, &inter_shape); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + overlap_size += inter_shape.num_elements(); + } + } + if (total_size == overlap_size) { + // We have it! + // Now we need to copy the data to "data" + if (data) { + for (const auto x : slices_) { + CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice, + x.second.data, data); + } + } + return true; + } else { + // We don't have all the data for the asked tensor slice + return false; + } + } +} + +bool TensorSliceSet::QueryMeta( + const TensorSlice& slice, + std::vector>* results) const { + results->clear(); + Status s; + string str = slice.DebugString(); + // First we check if there is an exactly match (this is the dominant case). + const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str); + if (info) { + results->emplace_back(std::make_pair(info->slice, info->tag)); + return true; + } else { + // We didn't find any exact match but there is still a posibility that + // multiple existing slices can be patched together to output the slice. + // We figure this out by computing the intersection of each of the existing + // slices with the query slice, and check if the union of all these + // intersections cover the entire slice. We rely on the fact that the + // existing slices don't have any intersection among themselves. + TensorShape target_shape; + Status s; + s = slice.SliceTensorShape(shape_, &target_shape); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + int64 total_size = target_shape.num_elements(); + + int64 overlap_size = 0; + TensorSlice intersection; + TensorShape inter_shape; + for (const auto x : slices_) { + if (slice.Intersect(x.second.slice, &intersection)) { + s = intersection.SliceTensorShape(shape_, &inter_shape); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + overlap_size += inter_shape.num_elements(); + results->emplace_back(std::make_pair(x.second.slice, x.second.tag)); + } + } + if (total_size == overlap_size) { + // We have it! + return true; + } else { + // We don't have all the data for the asked tensor slice + results->clear(); + return false; + } + } +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_set.h b/tensorflow/core/util/tensor_slice_set.h new file mode 100644 index 0000000000..f3f7ac0e76 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_set.h @@ -0,0 +1,73 @@ +// A class to manage slices of a tensor. You can "register" set of slices for a +// tensor and then "query" if we have data for a given slice. + +// TODO(yangke): consider moving it to a more private place so that we don't +// need to expose the API. 
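+//
+// Illustrative usage sketch (not part of the original header), mirroring the
+// cases in tensor_slice_set_test.cc:
+//
+//   TensorSliceSet tss(TensorShape({4, 5}), DT_FLOAT);
+//   const float top_rows[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+//   TF_CHECK_OK(tss.Register(TensorSlice::ParseOrDie("0,2:-"), "file_0",
+//                            top_rows));
+//   float row1[5];
+//   if (tss.Query(TensorSlice::ParseOrDie("1,1:-"), row1)) {
+//     // row1 now holds {5, 6, 7, 8, 9}.
+//   }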
+ +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_ + +#include // for string +#include + +#include "tensorflow/core/platform/port.h" // for int64 +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/core/stringpiece.h" // for StringPiece +#include "tensorflow/core/public/status.h" // for Status + +namespace tensorflow { + +namespace checkpoint { + +class TensorSliceSet { + public: + TensorSliceSet(const TensorShape& shape, DataType type); + virtual ~TensorSliceSet(); + + const TensorShape& shape() const { return shape_; } + const DataType type() const { return type_; } + + // Register a new slice for the tensor. The "tag" is an arbitrary string + // associated with the slice (in one application it denotes the name of the + // file that contains the slice); the "data" points to the data of the tensor + // slice (it can be a nullptr). + // We don't take the ownership of "data" and the caller needs to make sure + // the data is always available during the life time of the tensor slice set + // if it is not nullptr. + Status Register(const TensorSlice& slice, const string& tag, + const float* data); + + // Query about a new slice: checks if we have data for "slice" and if we have + // the data and "data" is not nullptr, fill "data" with the slice data. The + // caller needs to make sure "data" point to a large eough buffer. + // TODO(yangke): avoid unnecessary copying by using a core::RefCounted + // pointer. + bool Query(const TensorSlice& slice, float* data) const; + + // Alternative way of querying about a new slice: instead of copying the + // data, it returns a list of meta data about the stored slices that will + // supply data for the slice. + bool QueryMeta( + const TensorSlice& slice, + std::vector>* results) const; + + private: + const TensorShape shape_; + const DataType type_; + struct SliceInfo { + TensorSlice slice; + const string tag; + const float* data; + int64 num_floats; + }; + // We maintain a mapping from the slice string to the slice information. + std::unordered_map slices_; +}; + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_ diff --git a/tensorflow/core/util/tensor_slice_set_test.cc b/tensorflow/core/util/tensor_slice_set_test.cc new file mode 100644 index 0000000000..fb2f46f34c --- /dev/null +++ b/tensorflow/core/util/tensor_slice_set_test.cc @@ -0,0 +1,227 @@ +#include "tensorflow/core/util/tensor_slice_set.h" + +#include "tensorflow/core/platform/logging.h" +#include +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +// A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this: +// +// 0 1 2 3 4 +// 5 6 7 8 9 +// 10 11 12 13 14 +// 15 16 17 18 19 +// +// We assume this is a row-major matrix. +// +// We store the tensor in a couple of slices and verify that we can recover all +// of them. +TEST(TensorSliceSetTest, QueryTwoD) { + TensorShape shape({4, 5}); + + TensorSliceSet tss(shape, DT_FLOAT); + // We store a few slices. + + // Slice #1 is the top two rows: + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(tss.Register(slice_1, "", src_1)); + + // Slice #2 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . 
+ // 15 16 17 . . + const float src_2[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(tss.Register(slice_2, "", src_2)); + + // Slice #3 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + const float src_3[] = {18, 19}; + TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(tss.Register(slice_3, "", src_3)); + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we query some of the slices + + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + float results[10]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + float expected[] = {5, 6, 7, 8, 9}; + float results[5]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #3 is a more complicated match: it needs the combination of a couple + // of slices + // . . . . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3"); + float expected[] = {5, 6, 7, 10, 11, 12}; + float results[6]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + float results[6]; + EXPECT_FALSE(tss.Query(s, results)); + } +} + +// Testing the meta version of the tensor slice set. +TEST(TensorSliceSetTest, QueryMetaTwoD) { + TensorShape shape({4, 5}); + + TensorSliceSet tss(shape, DT_INT32); + // We store a few slices. + + // Slice #1 is the top two rows: + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr)); + + // Slice #2 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr)); + + // Slice #3 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr)); + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we query some of the slices + + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + // We just need slice_1 for this + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + std::vector> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(1, results.size()); + EXPECT_EQ("0,2:-", results[0].first.DebugString()); + EXPECT_EQ("slice_1", results[0].second); + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . 
+ // We just need slice_1 for this + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + std::vector> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(1, results.size()); + EXPECT_EQ("0,2:-", results[0].first.DebugString()); + EXPECT_EQ("slice_1", results[0].second); + } + + // Slice #3 is a more complicated match: it needs the combination of a couple + // of slices + // . . . . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + // We need both slice_1 and slice_2 for this. + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3"); + std::vector> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(2, results.size()); + EXPECT_EQ("2,2:0,3", results[0].first.DebugString()); + EXPECT_EQ("slice_2", results[0].second); + EXPECT_EQ("0,2:-", results[1].first.DebugString()); + EXPECT_EQ("slice_1", results[1].second); + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + std::vector> results; + EXPECT_FALSE(tss.QueryMeta(s, &results)); + EXPECT_EQ(0, results.size()); + } +} + +} // namespace + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_util.h b/tensorflow/core/util/tensor_slice_util.h new file mode 100644 index 0000000000..5422c3bef3 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_util.h @@ -0,0 +1,88 @@ +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ + +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +// Some hackery to invoke eigen tensor to copy over tensor slices with variable +// dimension tensors. +// TODO(yangke): get rid of that once the variable dimension tensor support is +// in. +static const int kTensorSliceMaxRank = 8; + +// Create a tensor map with the given shape: we support up to 8 dimensions. If +// the shape has less than 8 dimensions, we pad the remaining dimension with 1. +template +Eigen::TensorMap> +GetEigenTensorMapFromTensorShape(const TensorShape& shape, T* data) { + Eigen::DSizes dsizes = + shape.AsEigenDSizesWithPadding(); + Eigen::TensorMap> eig( + data, dsizes); + return eig; +} + +// Given a tensor described by "shape", two slices "slice_s" and "slice_d", +// and two pointers "ptr_s" and "ptr_d", where "ptr_s" points to a chunk of +// memory that stores the data for "slice_s" and "ptr_d" points to a chunk of +// memory that stores the data for "slice_d". This function copies the data +// that belongs to the intersection of the two slices from slice_s to +// slice_d. Uses Tensor cast() to convert from SrcT to DstT. Returns true +// iff the two slices share any intersection (and thus some data is copied). +// TODO(yangke): figure out if we can make it private. +template +static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape, + const TensorSlice& slice_s, + const TensorSlice& slice_d, + const SrcT* ptr_s, + DstT* ptr_d) { + CHECK_LE(shape.dims(), kTensorSliceMaxRank) << "Only tensors of size up to " + << kTensorSliceMaxRank + << " are supported"; + // We need to compute the intersection of the two slices. + TensorSlice inter; + if (!slice_s.Intersect(slice_d, &inter)) { + // There is no intersection: returns false. + return false; + } else { + // We need to compute the applied shapes after applying slice_s and + // slice_d. 
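+ // The "applied shape" is the extent of the slice itself (what
+ // SliceTensorShape() produces); ptr_s and ptr_d are assumed to be dense
+ // row-major buffers of exactly these shapes, which is why the Eigen maps
+ // below are built from shp_s and shp_d rather than from the full tensor
+ // shape.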
+ TensorShape shp_s, shp_d; + Status s; + s = slice_s.SliceTensorShape(shape, &shp_s); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + s = slice_d.SliceTensorShape(shape, &shp_d); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + + // We need to compute the relative slice of "inter" w.r.t. both slice_s and + // slice_d. + TensorSlice rel_s, rel_d; + slice_s.ComputeRelative(inter, &rel_s); + slice_d.ComputeRelative(inter, &rel_d); + + // Get the eigen tensor maps to the data. + auto t_s = GetEigenTensorMapFromTensorShape(shp_s, ptr_s); + auto t_d = GetEigenTensorMapFromTensorShape(shp_d, ptr_d); + + Eigen::DSizes s_start, s_len, + d_start, d_len; + + rel_s.FillIndicesAndSizes(shp_s, &s_start, &s_len); + rel_d.FillIndicesAndSizes(shp_d, &d_start, &d_len); + t_d.slice(d_start, d_len) = t_s.slice(s_start, s_len).template cast(); + return true; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ diff --git a/tensorflow/core/util/tensor_slice_util_test.cc b/tensorflow/core/util/tensor_slice_util_test.cc new file mode 100644 index 0000000000..348b0c884e --- /dev/null +++ b/tensorflow/core/util/tensor_slice_util_test.cc @@ -0,0 +1,91 @@ +#include "tensorflow/core/util/tensor_slice_util.h" + +#include + +namespace tensorflow { +namespace { + +// Testing copying data from one tensor slice to another tensor slice +TEST(TensorSliceUtilTest, CopyTensorSliceToTensorSlice) { + // We map out a 2-d tensor of size 4 X 5 and we want the final results look + // like this: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // 10 11 12 13 14 + // 15 16 17 18 19 + // + // We assume this is a row-major matrix + // + TensorShape shape({4, 5}); + + // We will try to do a couple of slice to slice copies. + + // Case 1: simple identity copy + // The slice is the "interior" of the matrix + // . . . . . + // . 6 7 8 . + // , 11 12 13 . + // . . . . . + { + TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,3"); + const float ptr_s[] = {6, 7, 8, 11, 12, 13}; + float ptr_d[6]; + for (int i = 0; i < 6; ++i) { + ptr_d[i] = 0; + } + EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(ptr_s[i], ptr_d[i]); + } + } + + // Case 2: no intersection + { + TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("3,1:2,3"); + const float ptr_s[] = {6, 7, 8, 11, 12, 13}; + float ptr_d[6]; + EXPECT_FALSE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + } + + // Case 3: a trickier case + // The source slice is on the upper left corner: + // 0 1 2 . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + // + // The destination slice is the right part of the middle stripe: + // . . . . . + // . X X X X + // . X X X X + // . . . . . + // + // So we expect to copy over the 2X2 block: + // . . . . . + // . 6 7 . . + // . 11 12 . . + // . . . . . 
+ { + TensorSlice slice_s = TensorSlice::ParseOrDie("0,3:0,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,4"); + const float ptr_s[] = {0, 1, 2, 5, 6, 7, 10, 11, 12}; + float ptr_d[8]; + for (int i = 0; i < 8; ++i) { + ptr_d[i] = 0; + } + EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + const float expected[] = {6, 7, 0, 0, 11, 12, 0, 0}; + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(expected[i], ptr_d[i]); + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc new file mode 100644 index 0000000000..bb2fd96c05 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_writer.cc @@ -0,0 +1,110 @@ +#include "tensorflow/core/util/tensor_slice_writer.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +class TableBuilder : public TensorSliceWriter::Builder { + public: + TableBuilder(const string& name, WritableFile* f) + : name_(name), + file_(f), + builder_(new table::TableBuilder(table::Options(), f)) {} + void Add(StringPiece key, StringPiece val) override { + builder_->Add(key, val); + } + Status Finish(int64* file_size) override { + *file_size = -1; + Status s = builder_->Finish(); + if (s.ok()) { + s = file_->Close(); + if (s.ok()) { + *file_size = builder_->FileSize(); + } + } + if (!s.ok()) { + s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ", + s.ToString()); + } + builder_.reset(); + file_.reset(); + return s; + } + + private: + string name_; + std::unique_ptr file_; + std::unique_ptr builder_; +}; +} // anonymous namespace + +Status CreateTableTensorSliceBuilder( + const string& name, TensorSliceWriter::Builder** builder) { + *builder = nullptr; + WritableFile* f; + Status s = Env::Default()->NewWritableFile(name, &f); + if (s.ok()) { + *builder = new TableBuilder(name, f); + return Status::OK(); + } else { + return s; + } +} + +TensorSliceWriter::TensorSliceWriter(const string& filename, + CreateBuilderFunction create_builder) + : filename_(filename), + create_builder_(create_builder), + tmpname_(strings::StrCat(filename, ".tempstate", random::New64())), + slices_(0) {} + +Status TensorSliceWriter::Finish() { + Builder* b; + Status s = create_builder_(tmpname_, &b); + if (!s.ok()) { + delete b; + return s; + } + std::unique_ptr builder(b); + + // We save the saved tensor slice metadata as the first element. 
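+  // (The per-slice entries added below come from data_, an ordered map, so
+  // they land in the table in sorted key order after this metadata entry.)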
+ string meta; + sts_.AppendToString(&meta); + builder->Add(kSavedTensorSlicesKey, meta); + + // Go through all the data and add them + for (const auto& x : data_) { + builder->Add(x.first, x.second); + } + + int64 file_size; + s = builder->Finish(&file_size); + // We need to rename the file to the proper name + if (s.ok()) { + s = Env::Default()->RenameFile(tmpname_, filename_); + if (s.ok()) { + VLOG(1) << "Written " << slices_ << " slices for " + << sts_.meta().tensor_size() << " tensors (" << file_size + << " bytes) to " << filename_; + } else { + LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_; + } + } else { + Env::Default()->DeleteFile(tmpname_); + } + return s; +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h new file mode 100644 index 0000000000..cce3880cb3 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_writer.h @@ -0,0 +1,149 @@ +// The utility to write checkpoints for google brain tensor ops and v3 +// checkpoints for dist_belief. +// + +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_ + +#include + +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/saved_tensor_slice.pb.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" + +namespace tensorflow { + +namespace checkpoint { + +class TensorSliceWriter { + public: + // Abstract interface that TensorSliceWriter uses for building + class Builder { + public: + virtual ~Builder() {} + virtual void Add(StringPiece key, StringPiece value) = 0; + virtual Status Finish(int64* file_size) = 0; + }; + typedef std::function + CreateBuilderFunction; + + TensorSliceWriter(const string& filename, + CreateBuilderFunction create_builder); + virtual ~TensorSliceWriter() {} + // Adds a slice. We support float and int32 for now. + // TODO(yangke): add more supports + template + Status Add(const string& name, const TensorShape& shape, + const TensorSlice& slice, const T* data); + Status Finish(); + + private: + // Allocate "num_elements" elements in "ss" and save the data in "data" + // there. + template + static void SaveData(const T* data, int num_elements, SavedSlice* ss); + + const string filename_; + const CreateBuilderFunction create_builder_; + const string tmpname_; + + // A mapping from the tensor names to their index in meta_.saved_slice_meta() + std::unordered_map name_to_index_; + // The metadata that holds all the saved tensor slices. 
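+  // (Only the meta() part of this proto is filled in; the per-slice data is
+  // serialized separately into "data_" below, keyed by
+  // EncodeTensorNameSlice(name, slice).)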
+ SavedTensorSlices sts_; + // The data to be written to the builder + std::map data_; + // Total number of slices written + int slices_; + TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceWriter); +}; + +template +Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, + const TensorSlice& slice, const T* data) { + // The tensor and the slice have to be compatible + if (shape.dims() != slice.dims()) { + return errors::Internal("Incompatible tensor shape and slice: ", "shape = ", + shape.DebugString(), ", slice = ", + slice.DebugString()); + } + DataType dt = DataTypeToEnum::value; + // We need to add an entry for "name" if there isn't an entry already. + int index = gtl::FindWithDefault(name_to_index_, name, -1); + if (index >= 0) { + // The same tensor has been registered -- we verify that the shapes and the + // type agree. + const SavedSliceMeta& ssm = sts_.meta().tensor(index); + CHECK_EQ(name, ssm.name()) << ssm.ShortDebugString(); + TensorShape ssm_shape(ssm.shape()); + if (!shape.IsSameSize(ssm_shape)) { + return errors::Internal("Mismatching shapes: existing tensor = ", + ssm_shape.DebugString(), ", trying to add name ", + name, ", shape = ", shape.DebugString()); + } + if (dt != ssm.type()) { + return errors::Internal( + "Mismatching types: existing type = ", DataTypeString(ssm.type()), + ", trying to add name ", name, ", type = ", DataTypeString(dt)); + } + } else { + // Insert the new tensor name with the shape information + index = sts_.meta().tensor_size(); + name_to_index_.insert(std::make_pair(name, index)); + SavedSliceMeta* ssm = sts_.mutable_meta()->add_tensor(); + ssm->set_name(name); + shape.AsProto(ssm->mutable_shape()); + ssm->set_type(dt); + } + // Now we need to add the slice info the list of slices. + SavedSliceMeta* ssm = sts_.mutable_meta()->mutable_tensor(index); + slice.AsProto(ssm->add_slice()); + + // Now we need to add the real data. + { + SavedTensorSlices sts; + SavedSlice* ss = sts.mutable_data(); + ss->set_name(name); + slice.AsProto(ss->mutable_slice()); + TensorShape saved_shape(ssm->shape()); + TensorShape sliced_shape; + TF_RETURN_IF_ERROR(slice.SliceTensorShape(saved_shape, &sliced_shape)); + SaveData(data, sliced_shape.num_elements(), ss); + string key = EncodeTensorNameSlice(name, slice); + // TODO(yangke): consider doing a two-pass thing where the first pass just + // list the tensor slices we want to save and then another pass to actually + // set the data. Need to figure out if the interface works well. + std::pair key_value(key, ""); + sts.AppendToString(&key_value.second); + data_.insert(key_value); + } + ++slices_; + return Status::OK(); +} + +template +void TensorSliceWriter::SaveData(const T* data, int num_elements, + SavedSlice* ss) { + Fill(data, num_elements, ss->mutable_data()); +} + +// Create a table builder that will write to "filename" in +// tensorflow::io::Table format. If successful, return OK +// and set "*builder" to the allocated builder. Otherwise, return a +// non-OK status. 
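+// For illustration, typical usage together with TensorSliceWriter looks
+// roughly like this (the file name and tensor name are made up):
+//
+//   TensorSliceWriter writer("/tmp/ckpt", CreateTableTensorSliceBuilder);
+//   TensorShape shape({4, 5});
+//   TensorSlice slice = TensorSlice::ParseOrDie("-:0,1");
+//   float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
+//   TF_CHECK_OK(writer.Add("my_tensor", shape, slice, data));
+//   TF_CHECK_OK(writer.Finish());
+//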
+Status CreateTableTensorSliceBuilder(const string& filename, + TensorSliceWriter::Builder** builder); + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_ diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc new file mode 100644 index 0000000000..ca3dffe422 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_writer_test.cc @@ -0,0 +1,248 @@ +#include "tensorflow/core/util/tensor_slice_writer.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_reader.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/test.h" +#include + +namespace tensorflow { + +namespace checkpoint { + +class TensorSliceWriteTestHelper { + public: + static void CheckEntries(const string& fname); + static void GetData(TensorSliceReader::Table* table, const string& name, + const TensorSlice& slice, SavedSlice* ss); +}; + +namespace { + +// Testing that an array is what is expected +void ExpectIdenticalFloatArrays(const float* expected, int size, + const float* actual) { + // TODO(yangke): copy some of the Dump* functions over + // LOG(INFO) << "Expected = " << DumpFloatArray(expected, size); + // LOG(INFO) << "Actual = " << DumpFloatArray(actual, size); + for (int i = 0; i < size; ++i) { + EXPECT_NEAR(expected[i], actual[i], 1e-6); + } +} + +template +void ExpectIdenticalIntArrays(const T* expected, int size, const U* actual) { + for (int i = 0; i < size; ++i) { + EXPECT_EQ(expected[i], static_cast(actual[i])); + } +} + +// Nifty routine to get the size of an array +template +inline size_t ArraySize(const T(&v)[SIZE]) { + return SIZE; +} + +// A simple test on writing a few tensor slices +// TODO(yangke): refactor into smaller tests: will do as we add more stuff to +// the writer. +TEST(TensorSliceWriteTest, SimpleWrite) { + const string filename = io::JoinPath(testing::TmpDir(), "checkpoint"); + + TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder); + + // Add some int32 tensor slices + { + TensorShape shape({5, 10}); + TensorSlice slice = TensorSlice::ParseOrDie("-:0,1"); + const int32 data[] = {0, 1, 2, 3, 4}; + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + + // Two slices share the same tensor name + { + TensorShape shape({5, 10}); + TensorSlice slice = TensorSlice::ParseOrDie("-:3,1"); + const int32 data[] = {10, 11, 12, 13, 14}; + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + + // Another slice from a different float tensor -- it has a different name and + // should be inserted in front of the previous tensor + { + TensorShape shape({3, 2}); + TensorSlice slice = TensorSlice::ParseOrDie("-:-"); + const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3}; + TF_CHECK_OK(writer.Add("AA", shape, slice, data)); + } + + // A slice with int64 data + { + TensorShape shape({5, 10}); + TensorSlice slice = TensorSlice::ParseOrDie("-:3,1"); + const int64 data[] = {10, 11, 12, 13, 14}; + TF_CHECK_OK(writer.Add("int64", shape, slice, data)); + } + + // A slice with int16 data + { + TensorShape shape({5, 10}); + TensorSlice slice = TensorSlice::ParseOrDie("-:3,1"); + const int16 data[] = {10, 11, 12, 13, 14}; + TF_CHECK_OK(writer.Add("int16", shape, slice, data)); + } + + TF_CHECK_OK(writer.Finish()); + + // Now we examine the checkpoint file manually. 
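+  // CheckEntries() below verifies both the metadata block (which should
+  // describe all four tensors) and the five per-slice data blocks written
+  // above.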
+ TensorSliceWriteTestHelper::CheckEntries(filename); +} + +} // namespace + +void TensorSliceWriteTestHelper::GetData(TensorSliceReader::Table* table, + const string& name, + const TensorSlice& slice, + SavedSlice* ss) { + string key = EncodeTensorNameSlice(name, slice); + string value; + EXPECT_TRUE(table->Get(key, &value)); + SavedTensorSlices sts; + EXPECT_TRUE(ParseProtoUnlimited(&sts, value)); + EXPECT_FALSE(sts.has_meta()); + *ss = sts.data(); + EXPECT_EQ(name, ss->name()); + TensorSlice slice2(ss->slice()); + EXPECT_EQ(slice.DebugString(), slice2.DebugString()); +} + +void TensorSliceWriteTestHelper::CheckEntries(const string& fname) { + TensorSliceReader::Table* tptr; + TF_CHECK_OK(OpenTableTensorSliceReader(fname, &tptr)); + std::unique_ptr table(tptr); + CHECK_NOTNULL(table.get()); + + // We expect a block of SavedTensorSlices + string value; + ASSERT_TRUE(table->Get(kSavedTensorSlicesKey, &value)); + { + SavedTensorSlices sts; + EXPECT_TRUE(ParseProtoUnlimited(&sts, value)); + // We also expect two entries for the tensors + EXPECT_TRUE(sts.has_meta()); + EXPECT_EQ(4, sts.meta().tensor_size()); + // We don't expect any data in the first block. + EXPECT_FALSE(sts.has_data()); + // The two tensors should be stored in the same order as they are first + // created. + { + // The two slices of the "test" tensor + const SavedSliceMeta& ssm = sts.meta().tensor(0); + EXPECT_EQ("test", ssm.name()); + EXPECT_EQ( + "dim { size: 5 } " + "dim { size: 10 }", + ssm.shape().ShortDebugString()); + EXPECT_EQ(DT_INT32, ssm.type()); + EXPECT_EQ(2, ssm.slice_size()); + TensorSlice s0(ssm.slice(0)); + TensorSlice s1(ssm.slice(1)); + EXPECT_EQ("-:0,1", s0.DebugString()); + EXPECT_EQ("-:3,1", s1.DebugString()); + } + { + // The "AA" tensor + const SavedSliceMeta& ssm = sts.meta().tensor(1); + EXPECT_EQ("AA", ssm.name()); + EXPECT_EQ( + "dim { size: 3 } " + "dim { size: 2 }", + ssm.shape().ShortDebugString()); + EXPECT_EQ(DT_FLOAT, ssm.type()); + EXPECT_EQ(1, ssm.slice_size()); + TensorSlice s0(ssm.slice(0)); + EXPECT_EQ("-:-", s0.DebugString()); + } + { + // The "int64" tensor + const SavedSliceMeta& ssm = sts.meta().tensor(2); + EXPECT_EQ("int64", ssm.name()); + EXPECT_EQ( + "dim { size: 5 } " + "dim { size: 10 }", + ssm.shape().ShortDebugString()); + EXPECT_EQ(DT_INT64, ssm.type()); + EXPECT_EQ(1, ssm.slice_size()); + TensorSlice s0(ssm.slice(0)); + EXPECT_EQ("-:3,1", s0.DebugString()); + } + { + // The "int16" tensor + const SavedSliceMeta& ssm = sts.meta().tensor(3); + EXPECT_EQ("int16", ssm.name()); + EXPECT_EQ( + "dim { size: 5 } " + "dim { size: 10 }", + ssm.shape().ShortDebugString()); + EXPECT_EQ(DT_INT16, ssm.type()); + EXPECT_EQ(1, ssm.slice_size()); + TensorSlice s0(ssm.slice(0)); + EXPECT_EQ("-:3,1", s0.DebugString()); + } + } + + // We expect 5 blocks of tensor data + { + // Block 1: we expect it to be the full slice of the "AA" tensor + SavedSlice ss; + GetData(table.get(), "AA", TensorSlice(2), &ss); + const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3}; + EXPECT_EQ(ArraySize(data), ss.data().float_val_size()); + ExpectIdenticalFloatArrays(data, ArraySize(data), + ss.data().float_val().data()); + } + + { + // Block 2: we expect it to be the first slice of the "test" tensor + SavedSlice ss; + GetData(table.get(), "test", TensorSlice({{0, -1}, {0, 1}}), &ss); + const int32 data[] = {0, 1, 2, 3, 4}; + EXPECT_EQ(ArraySize(data), ss.data().int_val_size()); + ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data()); + } + + { + // Block 3: we expect it to be the second slice 
of the "test" tensor + SavedSlice ss; + GetData(table.get(), "test", TensorSlice({{0, -1}, {3, 1}}), &ss); + const int32 data[] = {10, 11, 12, 13, 14}; + EXPECT_EQ(ArraySize(data), ss.data().int_val_size()); + ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data()); + } + + { + // Block 4: we expect it to be the slice of the "int64" tensor + SavedSlice ss; + GetData(table.get(), "int64", TensorSlice({{0, -1}, {3, 1}}), &ss); + const int64 data[] = {10, 11, 12, 13, 14}; + EXPECT_EQ(ArraySize(data), ss.data().int64_val_size()); + ExpectIdenticalIntArrays(data, ArraySize(data), + ss.data().int64_val().data()); + } + + { + // Block 5: we expect it to be the slice of the "int16" tensor + SavedSlice ss; + GetData(table.get(), "int16", TensorSlice({{0, -1}, {3, 1}}), &ss); + const int16 data[] = {10, 11, 12, 13, 14}; + EXPECT_EQ(ArraySize(data), ss.data().int_val_size()); + ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data()); + } +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/use_cudnn.cc b/tensorflow/core/util/use_cudnn.cc new file mode 100644 index 0000000000..544b48a679 --- /dev/null +++ b/tensorflow/core/util/use_cudnn.cc @@ -0,0 +1,20 @@ +#include "tensorflow/core/util/use_cudnn.h" + +#include + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +bool CanUseCudnn() { + const char* tf_use_cudnn = getenv("TF_USE_CUDNN"); + if (tf_use_cudnn != nullptr) { + string tf_use_cudnn_str = tf_use_cudnn; + if (tf_use_cudnn_str == "0") { + return false; + } + } + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h new file mode 100644 index 0000000000..20ce24c513 --- /dev/null +++ b/tensorflow/core/util/use_cudnn.h @@ -0,0 +1,12 @@ +// The utility to check whether we have Cudnn depenedency. + +#ifndef TENSORFLOW_UTIL_USE_CUDNN_H_ +#define TENSORFLOW_UTIL_USE_CUDNN_H_ + +namespace tensorflow { + +bool CanUseCudnn(); + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_USE_CUDNN_H_ diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc new file mode 100644 index 0000000000..14ac513074 --- /dev/null +++ b/tensorflow/core/util/util.cc @@ -0,0 +1,81 @@ +#include "tensorflow/core/util/util.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +StringPiece NodeNamePrefix(const StringPiece& op_name) { + StringPiece sp(op_name); + auto p = sp.find('/'); + if (p == StringPiece::npos || p == 0) { + return ""; + } else { + return StringPiece(sp.data(), p); + } +} + +StringPiece NodeNameFullPrefix(const StringPiece& op_name) { + StringPiece sp(op_name); + auto p = sp.rfind('/'); + if (p == StringPiece::npos || p == 0) { + return ""; + } else { + return StringPiece(sp.data(), p); + } +} + +MovingAverage::MovingAverage(int window) + : window_(window), + sum_(0.0), + data_(new double[window_]), + head_(0), + count_(0) { + CHECK_GE(window, 1); +} + +MovingAverage::~MovingAverage() { delete[] data_; } + +void MovingAverage::Clear() { + count_ = 0; + head_ = 0; + sum_ = 0; +} + +double MovingAverage::GetAverage() const { + if (count_ == 0) { + return 0; + } else { + return static_cast(sum_) / count_; + } +} + +void MovingAverage::AddValue(double v) { + if (count_ < window_) { + // This is the warmup phase. We don't have a full window's worth of data. 
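+    // For example, with window_ == 3 and inputs 1, 2, 3, 4: the first three
+    // calls fill data_ with {1, 2, 3} (sum_ == 6); the fourth call takes the
+    // else branch, overwrites the oldest value (1) with 4, and leaves
+    // sum_ == 9, so GetAverage() returns 3.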
+ head_ = count_; + data_[count_++] = v; + } else { + if (window_ == ++head_) { + head_ = 0; + } + // Toss the oldest element + sum_ -= data_[head_]; + // Add the newest element + data_[head_] = v; + } + sum_ += v; +} + +static char hex_char[] = "0123456789abcdef"; + +string PrintMemory(const char* ptr, int n) { + string ret; + ret.resize(n * 3); + for (int i = 0; i < n; ++i) { + ret[i * 3] = ' '; + ret[i * 3 + 1] = hex_char[ptr[i] >> 4]; + ret[i * 3 + 2] = hex_char[ptr[i] & 0xf]; + } + return ret; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h new file mode 100644 index 0000000000..52650bd8ea --- /dev/null +++ b/tensorflow/core/util/util.h @@ -0,0 +1,40 @@ +#ifndef TENSORFLOW_UTIL_UTIL_H_ +#define TENSORFLOW_UTIL_UTIL_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +// If op_name has '/' in it, then return everything before the first '/'. +// Otherwise return empty string. +StringPiece NodeNamePrefix(const StringPiece& op_name); + +// If op_name has '/' in it, then return everything before the last '/'. +// Otherwise return empty string. +StringPiece NodeNameFullPrefix(const StringPiece& op_name); + +class MovingAverage { + public: + explicit MovingAverage(int window); + ~MovingAverage(); + + void Clear(); + + double GetAverage() const; + void AddValue(double v); + + private: + const int window_; // Max size of interval + double sum_; // Sum over interval + double* data_; // Actual data values + int head_; // Offset of the newest statistic in data_ + int count_; // # of valid data elements in window +}; + +// Returns a string printing bytes in ptr[0..n). The output looks +// like "00 01 ef cd cd ef". +string PrintMemory(const char* ptr, int n); + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_UTIL_H_ diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc new file mode 100644 index 0000000000..d9ab0805c5 --- /dev/null +++ b/tensorflow/core/util/work_sharder.cc @@ -0,0 +1,57 @@ +#include "tensorflow/core/util/work_sharder.h" + +#include +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +void Shard(int num_workers, thread::ThreadPool* workers, int64 total, + int64 cost_per_unit, std::function work) { + CHECK_GE(total, 0); + if (total == 0) { + return; + } + if (num_workers <= 1) { + // Just inline the whole work since we only have 1 thread (core). + work(0, total); + return; + } + cost_per_unit = std::max(1LL, cost_per_unit); + // We shard [0, total) into "num_shards" shards. + // 1 <= num_shards <= num worker threads + // + // If total * cost_per_unit is small, it is not worth shard too + // much. Let us assume each cost unit is 1ns, kMinCostPerShard=10000 + // is 10us. + static const int64 kMinCostPerShard = 10000; + const int num_shards = std::max( + 1, std::min(num_workers, total * cost_per_unit / kMinCostPerShard)); + // Each shard contains up to "block_size" units. [0, total) is sharded + // into: + // [0, block_size), [block_size, 2*block_size), ... + // The 1st shard is done by the caller thread and the other shards + // are dispatched to the worker threads. The last shard may be smaller than + // block_size. + const int64 block_size = (total + num_shards - 1) / num_shards; + CHECK_GT(block_size, 0); // total > 0 guarantees this. 
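+  // Example with illustrative numbers: num_workers = 4, total = 1000000 and
+  // cost_per_unit = 1 give num_shards = min(4, 1000000 / 10000) = 4 and
+  // block_size = 250000, so the caller computes [0, 250000) inline and the
+  // remaining three blocks are scheduled on the worker pool below.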
+ if (block_size >= total) { + work(0, total); + return; + } + const int num_shards_used = (total + block_size - 1) / block_size; + BlockingCounter counter(num_shards_used - 1); + for (int64 start = block_size; start < total; start += block_size) { + auto limit = std::min(start + block_size, total); + workers->Schedule([&work, &counter, start, limit]() { + work(start, limit); // Compute the shard. + counter.DecrementCount(); // The shard is done. + }); + } + + // Inline execute the 1st shard. + work(0, std::min(block_size, total)); + counter.Wait(); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h new file mode 100644 index 0000000000..1ea2cf4397 --- /dev/null +++ b/tensorflow/core/util/work_sharder.h @@ -0,0 +1,33 @@ +#ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_ +#define TENSORFLOW_UTIL_WORK_SHARDER_H_ + +#include + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +// Shards the "total" unit of work assuming each unit of work having +// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ..., +// total - 1. Each shard contains 1 or more units of work and the +// total cost of each shard is roughly the same. The total number of +// shards is no more than num_workers. The calling thread and the +// "workers" are used to compute each shard (calling work(start, +// limit). A common configuration is that "workers" is a thread pool +// with "num_workers" threads. +// +// "work" should be a callable taking (int64, int64) arguments. +// work(start, limit) computes the work units from [start, +// limit), i.e., [start, limit) is a shard. +// +// REQUIRES: num_workers >= 0 +// REQUIRES: workers != nullptr +// REQUIRES: total >= 0 +// REQUIRES: cost_per_unit >= 0 +void Shard(int num_workers, thread::ThreadPool* workers, int64 total, + int64 cost_per_unit, std::function work); + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_WORK_SHARDER_H_ diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc new file mode 100644 index 0000000000..d9792c0e8d --- /dev/null +++ b/tensorflow/core/util/work_sharder_test.cc @@ -0,0 +1,57 @@ +#include "tensorflow/core/util/work_sharder.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include + +namespace tensorflow { +namespace { + +void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit) { + thread::ThreadPool threads(Env::Default(), "test", 16); + mutex mu; + int64 num_shards = 0; + int64 num_done_work = 0; + std::vector work(total, false); + Shard(num_workers, &threads, total, cost_per_unit, + [&mu, &num_shards, &num_done_work, &work](int start, int limit) { + VLOG(1) << "Shard [" << start << "," << limit << ")"; + mutex_lock l(mu); + ++num_shards; + for (; start < limit; ++start) { + EXPECT_FALSE(work[start]); // No duplicate + ++num_done_work; + work[start] = true; + } + }); + EXPECT_LE(num_shards, num_workers + 1); + EXPECT_EQ(num_done_work, total); + LOG(INFO) << num_workers << " " << total << " " << cost_per_unit << " " + << num_shards; +} + +TEST(Shard, Basic) { + for (auto workers : {0, 1, 2, 3, 5, 7, 10, 11, 15, 100, 1000}) { + for (auto total : {0, 1, 7, 10, 64, 100, 256, 1000, 9999}) { + for (auto cost_per_unit : {0, 1, 11, 102, 1003, 10005, 1000007}) { + RunSharding(workers, total, 
cost_per_unit); + } + } + } +} + +void BM_Sharding(int iters, int arg) { + thread::ThreadPool threads(Env::Default(), "test", 16); + const int64 total = 1LL << 30; + auto lambda = [](int64 start, int64 limit) {}; + auto work = std::cref(lambda); + for (; iters > 0; iters -= arg) { + Shard(arg - 1, &threads, total, 1, work); + } +} +BENCHMARK(BM_Sharding)->Range(1, 128); + +} // namespace +} // namespace tensorflow -- cgit v1.2.3
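As a quick illustration of the Shard() API declared in work_sharder.h, the sketch below parallelizes a simple element-wise loop. The helper name, thread-pool size, and cost_per_unit estimate are made up for the example; only Shard, thread::ThreadPool, and Env::Default() come from the sources above.

#include <vector>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/env.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// Squares every element of "in" into "out", splitting the index range
// [0, in.size()) across the calling thread and a small thread pool.
void ParallelSquare(const std::vector<float>& in, std::vector<float>* out) {
  out->resize(in.size());
  thread::ThreadPool pool(Env::Default(), "square", 4);
  Shard(4 /* num_workers */, &pool, in.size(), 10 /* cost_per_unit, a guess */,
        [&in, out](int64 start, int64 limit) {
          for (int64 i = start; i < limit; ++i) {
            (*out)[i] = in[i] * in[i];
          }
        });
}

}  // namespace tensorflow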